Updated Transformation development doc (#2391)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Wed, 23 Sep 2020 14:26:12 +0000 (17:26 +0300)
committerGitHub <noreply@github.com>
Wed, 23 Sep 2020 14:26:12 +0000 (17:26 +0300)
docs/IE_PLUGIN_DG/NewTransformation.md
docs/IE_PLUGIN_DG/images/graph_rewrite_efficient_search.png [new file with mode: 0644]
docs/IE_PLUGIN_DG/images/graph_rewrite_execution.png [new file with mode: 0644]
docs/IE_PLUGIN_DG/images/register_new_node.png [new file with mode: 0644]
docs/IE_PLUGIN_DG/images/transformations_structure.png [new file with mode: 0644]
docs/examples/example_ngraph_utils.cpp
docs/template_plugin/src/template_function_transformation.cpp
docs/template_plugin/src/template_function_transformation.hpp
docs/template_plugin/src/template_pattern_transformation.cpp
docs/template_plugin/src/template_pattern_transformation.hpp

index 55074e0..4f23f3a 100644 (file)
@@ -63,6 +63,8 @@ Below you can find examples how `ngraph::Function` can be created:
 nGraph has tree main transformation types: `ngraph::pass::FunctionPass` - strait forward way to work with `ngraph::Function` directly;
 `ngraph::pass::MatcherPass` - pattern based transformation approach; `ngraph::pass::GraphRewrite` - container for matcher passes.
 
+![transformations_structure]
+
 ###1. ngraph::pass::FunctionPass <a name="function_pass"></a>
 
 `ngraph::pass::FunctionPass` is used for transformations that take entire `ngraph::Function` as input and process it.
@@ -131,7 +133,7 @@ The last step is to register Matcher and callback inside MatcherPass pass. And t
 
 ```cpp
 // Register matcher and callback
-this->register_matcher(m, callback);
+register_matcher(m, callback);
 ```
 ### Matcher pass execution
 MatcherPass has multiple ways to be executed:
@@ -154,21 +156,32 @@ In addition GraphRewrite handles nodes that were registered by MatcherPasses dur
 
 > **Note**: when using `pass::Manager` temporary GraphRewrite is used to execute single MatcherPass. 
 
+GraphRewrite has two algorithms for MatcherPasses execution. First algorithm is a straight-forward. It applies each MatcherPass in registraion order to current node.
+
+![graph_rewrite_execution]
+
+But it is nor really efficient when you have a lot of registered passes. So first of all GraphRewrite check that all MatcherPass patterns has type based root node (it means that type of this node is not hidden into predicate).
+And then creates map from registered MatcherPases. That helps to avoid additional cost of applying each MatcherPass for each node.
+
+![graph_rewrite_efficient_search] 
+
 ## Pattern matching <a name="pattern_matching"></a>
 
-Sometimes patterns can't be expressed via regular nGraph operations. For example if you want to detect Convolution->Add sub-graph without specifying particular input type for Convolution operation or you want to create pattern where some of operations can have different types.
+Sometimes patterns can't be expressed via regular nGraph operations or it is too complicated. 
+For example if you want to detect Convolution->Add sub-graph without specifying particular input type for Convolution operation or you want to create pattern where some of operations can have different types.
 And for these cases nGraph provides additional helpers to construct patterns for GraphRewrite transformations. 
 
 There are two main helpers:
-1. `ngraph::pattern::op::Label` - helps to express inputs if their type is undefined.
-2. `ngraph::pattern::op::Any` - helps to express intermediate nodes of pattern if their type is unknown.
+1. `ngraph::pattern::any_input` - helps to express inputs if their types are undefined.
+2. `ngraph::pattern::wrap_type<T>` - helps to express nodes of pattern without specifying node attributes.
 
 Let's go through example to have better understanding how it works:
 
 > **Note**: node attributes do not participate in pattern matching and needed only for operations creation. Only operation types participate in pattern matching.
 
-Example below shows basic usage of `pattern::op::Label` class.
-Here we construct Multiply pattern with arbitrary first input and Constant as a second input.
+Example below shows basic usage of `pattern::any_input`.
+Here we construct Multiply pattern with arbitrary first input and Constant as a second input. 
+Also as Multiply is commutative operation it does not matter in which order we set inputs (any_input/Constant or Constant/any_input) because both cases will be matched.
 
 @snippet example_ngraph_utils.cpp pattern:label_example
 
@@ -176,7 +189,7 @@ This example show how we can construct pattern when operation has arbitrary numb
 
 @snippet example_ngraph_utils.cpp pattern:concat_example
 
-This example shows how to use predicate to construct pattern where operation has two different types. Also it shows how to match pattern manually on given node.
+This example shows how to use predicate to construct pattern. Also it shows how to match pattern manually on given node.
 
 @snippet example_ngraph_utils.cpp pattern:predicate_example
 
@@ -321,9 +334,11 @@ ngraph::copy_runtime_info({a, b, c}, {e, f});
 
 When transformation has multiple fusions or decompositions `ngraph::copy_runtime_info` must be called multiple times for each case. 
 
+> **Note**: copy_runtime_info removes rt_info from destination nodes. If you want to keep it you need to specify them in source nodes like this: copy_runtime_info({a, b, c}, {a, b})
+
 ###5. Constant Folding
 
-If your transformation inserts constant sub-graphs that needs to be folded do not forget to use `ngraph::pass::ConstantFolding()` after your transformation.
+If your transformation inserts constant sub-graphs that needs to be folded do not forget to use `ngraph::pass::ConstantFolding()` after your transformation or call constant folding directly for operation.
 Example below shows how constant sub-graph can be constructed.
 
 ```cpp
@@ -334,6 +349,12 @@ auto pow = std::make_shared<ngraph::opset3::Power>(
 auto mul = std::make_shared<ngraph::opset3::Multiply>(input /* not constant input */, pow);
 ``` 
 
+Manual constant folding is more preferable than `ngraph::pass::ConstantFolding()` because it is much faster.
+
+Below you can find an example of manual constant folding:
+
+@snippet src/template_pattern_transformation.cpp manual_constant_folding
+
 ## Common mistakes in transformations <a name="common_mistakes"></a>
 
 In transformation development process 
@@ -427,4 +448,8 @@ The basic transformation test looks like this:
 
 
 [ngraph_replace_node]: ../images/ngraph_replace_node.png
-[ngraph_insert_node]: ../images/ngraph_insert_node.png
\ No newline at end of file
+[ngraph_insert_node]: ../images/ngraph_insert_node.png
+[transformations_structure]: ../images/transformations_structure.png
+[register_new_node]: ../images/register_new_node.png
+[graph_rewrite_execution]: ../images/graph_rewrite_execution.png
+[graph_rewrite_efficient_search]: ../images/graph_rewrite_efficient_search.png
\ No newline at end of file
diff --git a/docs/IE_PLUGIN_DG/images/graph_rewrite_efficient_search.png b/docs/IE_PLUGIN_DG/images/graph_rewrite_efficient_search.png
new file mode 100644 (file)
index 0000000..1376398
--- /dev/null
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05eb8600d2c905975674f3a0a5dc676107d22f65f2a1f78ee1cfabc1771721ea
+size 41307
diff --git a/docs/IE_PLUGIN_DG/images/graph_rewrite_execution.png b/docs/IE_PLUGIN_DG/images/graph_rewrite_execution.png
new file mode 100644 (file)
index 0000000..17dc2d9
--- /dev/null
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17cd470c6d04d7aabbdb4a08e31f9c97eab960cf7ef5bbd3a541df92db38f26b
+size 40458
diff --git a/docs/IE_PLUGIN_DG/images/register_new_node.png b/docs/IE_PLUGIN_DG/images/register_new_node.png
new file mode 100644 (file)
index 0000000..3c34f65
--- /dev/null
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:80297287c81a2f27b7e74895738afd90844354a8dd745757e8321e2fb6ed547e
+size 31246
diff --git a/docs/IE_PLUGIN_DG/images/transformations_structure.png b/docs/IE_PLUGIN_DG/images/transformations_structure.png
new file mode 100644 (file)
index 0000000..953d667
--- /dev/null
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b206c602626f17ba5787810b9a28f9cde511448c3e63a5c7ba976cee7868bdb
+size 14907
index e46b040..1b17fc1 100644 (file)
@@ -4,6 +4,8 @@
 
 #include <memory>
 
+#include <ngraph/pattern/op/wrap_type.hpp>
+
 // ! [ngraph:include]
 #include <ngraph/ngraph.hpp>
 #include <ngraph/opsets/opset3.hpp>
@@ -89,7 +91,7 @@ ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
 // ! [pattern:label_example]
 // Detect Multiply with arbitrary first input and second as Constant
 // ngraph::pattern::op::Label - represent arbitrary input
-auto input = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{1});
+auto input = ngraph::pattern::any_input();
 auto value = ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0.5});
 auto mul = std::make_shared<ngraph::opset3::Multiply>(input, value);
 auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MultiplyMatcher");
@@ -99,20 +101,17 @@ auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MultiplyMatcher");
 {
 // ! [pattern:concat_example]
 // Detect Concat operation with arbitrary number of inputs
-auto concat = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset3::Concat>());
+auto concat = ngraph::pattern::wrap_type<ngraph::opset3::Concat>();
 auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "ConcatMatcher");
 // ! [pattern:concat_example]
 }
 
 {
 // ! [pattern:predicate_example]
-// Detect Multiply or Add operation
-auto lin_op = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{},
-              [](const std::shared_ptr<ngraph::Node> & node) -> bool {
-                    return std::dynamic_pointer_cast<ngraph::opset3::Multiply>(node) ||
-                           std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
-              });
-auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op, "MultiplyOrAddMatcher");
+// Detect Multiply->Add sequence where mul has exactly one consumer
+auto mul = ngraph::pattern::wrap_type<ngraph::opset3::Multiply>(ngraph::pattern::consumers_count(1)/*сheck consumers count*/);
+auto add = ngraph::pattern::wrap_type<ngraph::opset3::Add>({mul, ngraph::pattern::any_input()});
+auto m = std::make_shared<ngraph::pattern::Matcher>(add, "MultiplyAddMatcher");
 // Matcher can be used to match pattern manually on given node
 if (m->match(node->output(0))) {
     // Successfully matched
index aa2299d..a33994e 100644 (file)
@@ -10,7 +10,7 @@ using namespace ngraph;
 // template_function_transformation.cpp
 bool pass::MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
     // Example transformation code
-    std::vector<std::shared_ptr<Node> > nodes;
+    NodeVector nodes;
 
     // Traverse nGraph Function in topological order
     for (auto & node : f->get_ordered_ops()) {
index 72938be..5691e8b 100644 (file)
@@ -18,8 +18,6 @@ class MyFunctionTransformation;
 // template_function_transformation.hpp
 class ngraph::pass::MyFunctionTransformation: public ngraph::pass::FunctionPass {
 public:
-    MyFunctionTransformation() : FunctionPass() {}
-
     bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 };
 // ! [function_pass:template_transformation_hpp]
index e8ca30c..0dc0cf6 100644 (file)
@@ -16,8 +16,8 @@ using namespace ngraph;
 // template_pattern_transformation.cpp
 ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
     // Pattern example
-    auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
-    auto input1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
+    auto input0 = pattern::any_input();
+    auto input1 = pattern::any_input();
     auto div = std::make_shared<ngraph::opset3::Divide>(input0, input1);
 
     ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
@@ -49,7 +49,7 @@ ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
     // Register pattern with Divide operation as a pattern root node
     auto m = std::make_shared<ngraph::pattern::Matcher>(div, "ConvertDivide");
     // Register Matcher
-    this->register_matcher(m, callback);
+    register_matcher(m, callback);
 }
 // ! [graph_rewrite:template_transformation_cpp]
 
@@ -82,7 +82,7 @@ ngraph::pass::ReluReluFusionMatcher::ReluReluFusionMatcher() {
     // Register pattern with Relu operation as a pattern root node
     auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
     // Register Matcher
-    this->register_matcher(m, callback);
+    register_matcher(m, callback);
 }
 // ! [matcher_pass:relu_fusion]
 
@@ -137,3 +137,16 @@ pass.add_matcher<ngraph::pass::ReluReluFusionMatcher>();
 pass.run_on_function(f);
 // ! [matcher_pass:graph_rewrite]
 }
+
+// ! [manual_constant_folding]
+template <class T>
+Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
+    auto eltwise = std::make_shared<T>(input0, input1);
+    OutputVector output(eltwise->get_output_size());
+    // If constant folding wasn't successful return eltwise output
+    if (!eltwise->constant_fold(output, {input0, input1})) {
+        return eltwise->output(0);
+    }
+    return output[0];
+}
+// ! [manual_constant_folding]
index c9346ef..b51a233 100644 (file)
@@ -17,6 +17,10 @@ class ReluReluFusionMatcher;
 
 // ! [graph_rewrite:template_transformation_hpp]
 // template_pattern_transformation.hpp
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Add transformation description.
+ */
 class ngraph::pass::DecomposeDivideMatcher: public ngraph::pass::MatcherPass {
 public:
     DecomposeDivideMatcher();