nGraph has tree main transformation types: `ngraph::pass::FunctionPass` - strait forward way to work with `ngraph::Function` directly;
`ngraph::pass::MatcherPass` - pattern based transformation approach; `ngraph::pass::GraphRewrite` - container for matcher passes.
+![transformations_structure]
+
###1. ngraph::pass::FunctionPass <a name="function_pass"></a>
`ngraph::pass::FunctionPass` is used for transformations that take entire `ngraph::Function` as input and process it.
```cpp
// Register matcher and callback
-this->register_matcher(m, callback);
+register_matcher(m, callback);
```
### Matcher pass execution
MatcherPass has multiple ways to be executed:
> **Note**: when using `pass::Manager` temporary GraphRewrite is used to execute single MatcherPass.
+GraphRewrite has two algorithms for MatcherPasses execution. First algorithm is a straight-forward. It applies each MatcherPass in registraion order to current node.
+
+![graph_rewrite_execution]
+
+But it is nor really efficient when you have a lot of registered passes. So first of all GraphRewrite check that all MatcherPass patterns has type based root node (it means that type of this node is not hidden into predicate).
+And then creates map from registered MatcherPases. That helps to avoid additional cost of applying each MatcherPass for each node.
+
+![graph_rewrite_efficient_search]
+
## Pattern matching <a name="pattern_matching"></a>
-Sometimes patterns can't be expressed via regular nGraph operations. For example if you want to detect Convolution->Add sub-graph without specifying particular input type for Convolution operation or you want to create pattern where some of operations can have different types.
+Sometimes patterns can't be expressed via regular nGraph operations or it is too complicated.
+For example if you want to detect Convolution->Add sub-graph without specifying particular input type for Convolution operation or you want to create pattern where some of operations can have different types.
And for these cases nGraph provides additional helpers to construct patterns for GraphRewrite transformations.
There are two main helpers:
-1. `ngraph::pattern::op::Label` - helps to express inputs if their type is undefined.
-2. `ngraph::pattern::op::Any` - helps to express intermediate nodes of pattern if their type is unknown.
+1. `ngraph::pattern::any_input` - helps to express inputs if their types are undefined.
+2. `ngraph::pattern::wrap_type<T>` - helps to express nodes of pattern without specifying node attributes.
Let's go through example to have better understanding how it works:
> **Note**: node attributes do not participate in pattern matching and needed only for operations creation. Only operation types participate in pattern matching.
-Example below shows basic usage of `pattern::op::Label` class.
-Here we construct Multiply pattern with arbitrary first input and Constant as a second input.
+Example below shows basic usage of `pattern::any_input`.
+Here we construct Multiply pattern with arbitrary first input and Constant as a second input.
+Also as Multiply is commutative operation it does not matter in which order we set inputs (any_input/Constant or Constant/any_input) because both cases will be matched.
@snippet example_ngraph_utils.cpp pattern:label_example
@snippet example_ngraph_utils.cpp pattern:concat_example
-This example shows how to use predicate to construct pattern where operation has two different types. Also it shows how to match pattern manually on given node.
+This example shows how to use predicate to construct pattern. Also it shows how to match pattern manually on given node.
@snippet example_ngraph_utils.cpp pattern:predicate_example
When transformation has multiple fusions or decompositions `ngraph::copy_runtime_info` must be called multiple times for each case.
+> **Note**: copy_runtime_info removes rt_info from destination nodes. If you want to keep it you need to specify them in source nodes like this: copy_runtime_info({a, b, c}, {a, b})
+
###5. Constant Folding
-If your transformation inserts constant sub-graphs that needs to be folded do not forget to use `ngraph::pass::ConstantFolding()` after your transformation.
+If your transformation inserts constant sub-graphs that needs to be folded do not forget to use `ngraph::pass::ConstantFolding()` after your transformation or call constant folding directly for operation.
Example below shows how constant sub-graph can be constructed.
```cpp
auto mul = std::make_shared<ngraph::opset3::Multiply>(input /* not constant input */, pow);
```
+Manual constant folding is more preferable than `ngraph::pass::ConstantFolding()` because it is much faster.
+
+Below you can find an example of manual constant folding:
+
+@snippet src/template_pattern_transformation.cpp manual_constant_folding
+
## Common mistakes in transformations <a name="common_mistakes"></a>
In transformation development process
[ngraph_replace_node]: ../images/ngraph_replace_node.png
-[ngraph_insert_node]: ../images/ngraph_insert_node.png
\ No newline at end of file
+[ngraph_insert_node]: ../images/ngraph_insert_node.png
+[transformations_structure]: ../images/transformations_structure.png
+[register_new_node]: ../images/register_new_node.png
+[graph_rewrite_execution]: ../images/graph_rewrite_execution.png
+[graph_rewrite_efficient_search]: ../images/graph_rewrite_efficient_search.png
\ No newline at end of file
--- /dev/null
+version https://git-lfs.github.com/spec/v1
+oid sha256:05eb8600d2c905975674f3a0a5dc676107d22f65f2a1f78ee1cfabc1771721ea
+size 41307
--- /dev/null
+version https://git-lfs.github.com/spec/v1
+oid sha256:17cd470c6d04d7aabbdb4a08e31f9c97eab960cf7ef5bbd3a541df92db38f26b
+size 40458
--- /dev/null
+version https://git-lfs.github.com/spec/v1
+oid sha256:80297287c81a2f27b7e74895738afd90844354a8dd745757e8321e2fb6ed547e
+size 31246
--- /dev/null
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b206c602626f17ba5787810b9a28f9cde511448c3e63a5c7ba976cee7868bdb
+size 14907
#include <memory>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
// ! [ngraph:include]
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset3.hpp>
// ! [pattern:label_example]
// Detect Multiply with arbitrary first input and second as Constant
// ngraph::pattern::op::Label - represent arbitrary input
-auto input = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{1});
+auto input = ngraph::pattern::any_input();
auto value = ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0.5});
auto mul = std::make_shared<ngraph::opset3::Multiply>(input, value);
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MultiplyMatcher");
{
// ! [pattern:concat_example]
// Detect Concat operation with arbitrary number of inputs
-auto concat = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset3::Concat>());
+auto concat = ngraph::pattern::wrap_type<ngraph::opset3::Concat>();
auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "ConcatMatcher");
// ! [pattern:concat_example]
}
{
// ! [pattern:predicate_example]
-// Detect Multiply or Add operation
-auto lin_op = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{},
- [](const std::shared_ptr<ngraph::Node> & node) -> bool {
- return std::dynamic_pointer_cast<ngraph::opset3::Multiply>(node) ||
- std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
- });
-auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op, "MultiplyOrAddMatcher");
+// Detect Multiply->Add sequence where mul has exactly one consumer
+auto mul = ngraph::pattern::wrap_type<ngraph::opset3::Multiply>(ngraph::pattern::consumers_count(1)/*сheck consumers count*/);
+auto add = ngraph::pattern::wrap_type<ngraph::opset3::Add>({mul, ngraph::pattern::any_input()});
+auto m = std::make_shared<ngraph::pattern::Matcher>(add, "MultiplyAddMatcher");
// Matcher can be used to match pattern manually on given node
if (m->match(node->output(0))) {
// Successfully matched
// template_function_transformation.cpp
bool pass::MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
// Example transformation code
- std::vector<std::shared_ptr<Node> > nodes;
+ NodeVector nodes;
// Traverse nGraph Function in topological order
for (auto & node : f->get_ordered_ops()) {
// template_function_transformation.hpp
class ngraph::pass::MyFunctionTransformation: public ngraph::pass::FunctionPass {
public:
- MyFunctionTransformation() : FunctionPass() {}
-
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
// ! [function_pass:template_transformation_hpp]
// template_pattern_transformation.cpp
ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
// Pattern example
- auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
- auto input1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
+ auto input0 = pattern::any_input();
+ auto input1 = pattern::any_input();
auto div = std::make_shared<ngraph::opset3::Divide>(input0, input1);
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
// Register pattern with Divide operation as a pattern root node
auto m = std::make_shared<ngraph::pattern::Matcher>(div, "ConvertDivide");
// Register Matcher
- this->register_matcher(m, callback);
+ register_matcher(m, callback);
}
// ! [graph_rewrite:template_transformation_cpp]
// Register pattern with Relu operation as a pattern root node
auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
// Register Matcher
- this->register_matcher(m, callback);
+ register_matcher(m, callback);
}
// ! [matcher_pass:relu_fusion]
pass.run_on_function(f);
// ! [matcher_pass:graph_rewrite]
}
+
+// ! [manual_constant_folding]
+template <class T>
+Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
+ auto eltwise = std::make_shared<T>(input0, input1);
+ OutputVector output(eltwise->get_output_size());
+ // If constant folding wasn't successful return eltwise output
+ if (!eltwise->constant_fold(output, {input0, input1})) {
+ return eltwise->output(0);
+ }
+ return output[0];
+}
+// ! [manual_constant_folding]
// ! [graph_rewrite:template_transformation_hpp]
// template_pattern_transformation.hpp
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Add transformation description.
+ */
class ngraph::pass::DecomposeDivideMatcher: public ngraph::pass::MatcherPass {
public:
DecomposeDivideMatcher();