80069c38f94cf6ee55dedbb11e3877bea8ed64f4
[platform/upstream/dldt.git] / ngraph / test / pattern.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18 #include <cstdio>
19 #include <iostream>
20 #include <list>
21 #include <memory>
22 #include <ngraph/pattern/op/wrap_type.hpp>
23
24 #include "gtest/gtest.h"
25 #include "ngraph/file_util.hpp"
26 #include "ngraph/graph_util.hpp"
27 #include "ngraph/log.hpp"
28 #include "ngraph/ngraph.hpp"
29 #include "ngraph/op/add.hpp"
30 #include "ngraph/op/batch_norm.hpp"
31 #include "ngraph/op/constant.hpp"
32 #include "ngraph/op/divide.hpp"
33 #include "ngraph/op/multiply.hpp"
34 #include "ngraph/op/sqrt.hpp"
35 #include "ngraph/op/subtract.hpp"
36 #include "ngraph/op/sum.hpp"
37 #include "ngraph/op/sum.hpp"
38 #include "ngraph/op/util/op_types.hpp"
39 #include "ngraph/pass/graph_rewrite.hpp"
40 #include "ngraph/pass/manager.hpp"
41 #include "ngraph/pattern/matcher.hpp"
42 #include "ngraph/pattern/op/branch.hpp"
43 #include "ngraph/pattern/op/label.hpp"
44 #include "ngraph/pattern/op/or.hpp"
45 #include "ngraph/pattern/op/skip.hpp"
46 #include "ngraph/pattern/op/true.hpp"
47 #include "util/matcher.hpp"
48 #include "util/test_tools.hpp"
49
50 NGRAPH_SUPPRESS_DEPRECATED_START
51
52 using namespace ngraph;
53 using namespace std;
54
55 static std::shared_ptr<Node> construct_constant_node(int n)
56 {
57     return op::Constant::create(element::i32, Shape{}, {n});
58 }
59
60 static std::shared_ptr<pattern::op::Label> construct_variance_graph()
61 {
62     // construct varaiance
63     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
64     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
65     auto input_sq = std::make_shared<op::Multiply>(input, input);
66     auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
67     auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
68     auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
69     auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
70     auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
71     auto variance = std::make_shared<op::Divide>(xmu, N);
72     auto variance_label =
73         std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
74
75     return variance_label;
76 }
77
78 static std::shared_ptr<pattern::op::Label> construct_mean_graph()
79 {
80     // construct mean;
81     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
82     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
83     auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
84     auto mean = std::make_shared<op::Divide>(sum_input1, N);
85     auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
86     return mean_label;
87 }
88
89 class TestGraphRewrite : public ngraph::pass::GraphRewrite
90 {
91 public:
92     void construct_multiply_by_one()
93     {
94         // pattern #1 : a * 1 = a
95         auto iconst1 = construct_constant_node(1);
96         auto pattern = std::make_shared<pattern::op::Label>(iconst1);
97
98         auto callback = [pattern](pattern::Matcher& m) {
99             NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
100                          << m.get_match_root()->get_name();
101             NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
102
103             auto pattern_map = m.get_pattern_map();
104
105             size_t const_node_index =
106                 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
107             auto const_node = as_type_ptr<op::Constant>(
108                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
109             auto second_node =
110                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
111             NGRAPH_DEBUG << "second_node = " << second_node->get_name()
112                          << " , pattern = " << pattern_map[pattern]->get_name();
113
114             if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
115                 pattern_map[pattern]->get_shape() != const_node->get_shape())
116             {
117                 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
118                 return false;
119             }
120
121             auto const_values = const_node->get_vector<int32_t>();
122             bool all_ones =
123                 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
124
125             if (!all_ones)
126             {
127                 NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
128                 return false;
129             }
130
131             ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
132             return true;
133         };
134
135         auto m = make_shared<TestMatcher>(pattern * iconst1);
136         NGRAPH_SUPPRESS_DEPRECATED_START
137         this->add_matcher(m, callback);
138         NGRAPH_SUPPRESS_DEPRECATED_END
139     }
140
141     void construct_add_zero()
142     {
143         // pattern #2 : a + 0 = a
144         auto iconst0 = construct_constant_node(0);
145         auto pattern = std::make_shared<pattern::op::Label>(iconst0);
146
147         auto callback = [pattern](pattern::Matcher& m) {
148             NGRAPH_DEBUG << "In a callback for construct_add_zero against "
149                          << m.get_match_root()->get_name();
150             NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
151
152             auto pattern_map = m.get_pattern_map();
153
154             size_t const_node_index =
155                 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
156             auto const_node = as_type_ptr<op::Constant>(
157                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
158             auto second_node =
159                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
160             NGRAPH_DEBUG << "second_node = " << second_node->get_name()
161                          << " , pattern = " << pattern_map[pattern]->get_name();
162
163             if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
164                 pattern_map[pattern]->get_shape() != const_node->get_shape())
165             {
166                 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
167                 return false;
168             }
169
170             auto const_values = const_node->get_vector<int>();
171             bool all_zeros =
172                 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
173
174             if (!all_zeros)
175             {
176                 NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
177                 return false;
178             }
179
180             ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
181             return true;
182         };
183
184         auto add = pattern + iconst0;
185         auto m = make_shared<TestMatcher>(add);
186         NGRAPH_SUPPRESS_DEPRECATED_START
187         this->add_matcher(m, callback);
188         NGRAPH_SUPPRESS_DEPRECATED_END
189     }
190
191     TestGraphRewrite()
192         : GraphRewrite()
193     {
194         construct_multiply_by_one();
195         construct_add_zero();
196     }
197 };
198
199 static void run_passes(pass::Manager& pass_manager,
200                        shared_ptr<Node> graph,
201                        std::vector<shared_ptr<op::Parameter>> parms)
202 {
203     auto func = make_shared<Function>(graph, ParameterVector{parms});
204     pass_manager.run_passes(func);
205 }
206
207 TEST(pattern, graph_rewrite)
208 {
209     Shape shape{};
210     pass::Manager pass_manager;
211     pass_manager.register_pass<TestGraphRewrite>();
212
213     {
214         auto a = make_shared<op::Parameter>(element::i32, shape);
215         auto b = make_shared<op::Parameter>(element::i32, shape);
216         auto c = make_shared<op::Parameter>(element::i32, shape);
217         auto iconst0 = construct_constant_node(0);
218         auto graph_a = a + iconst0;
219         auto graph_b = b + iconst0;
220
221         auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
222                                             ParameterVector{a, b, c});
223         pass_manager.run_passes(f);
224
225         ASSERT_TRUE(graph_a->get_output_target_inputs(0).empty());
226         ASSERT_TRUE(graph_b->get_output_target_inputs(0).empty());
227
228         auto expected = ngraph::NodeVector{a, b, a, c, b};
229         ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
230     }
231
232     {
233         auto a = make_shared<op::Parameter>(element::i32, shape);
234         auto b = make_shared<op::Parameter>(element::i32, shape);
235         auto iconst0 = construct_constant_node(0);
236         auto sum = (a + iconst0);
237         auto graph = b + sum;
238         run_passes(pass_manager, graph, {a, b});
239         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
240         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
241         ASSERT_TRUE(sum->output(0)
242                         .get_target_inputs()
243                         .empty()); // graph's input is removed from sum's target inptus
244         ASSERT_TRUE(a->get_output_target_inputs(0).count(
245             graph->input(1))); // a's output feeds into graph's input
246     }
247
248     {
249         auto a = make_shared<op::Parameter>(element::i32, shape);
250         auto b = make_shared<op::Parameter>(element::i32, shape);
251         auto iconst1 = construct_constant_node(1);
252         auto mul = (a * iconst1);
253         auto graph = b + mul;
254         run_passes(pass_manager, graph, {a, b});
255         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
256         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
257         ASSERT_TRUE(mul->output(0)
258                         .get_target_inputs()
259                         .empty()); // graph's input is removed from sum's target inputs
260         ASSERT_TRUE(a->get_output_target_inputs(0).count(
261             graph->input(1))); // a's output feeds into graph's input
262     }
263
264     {
265         auto a = make_shared<op::Parameter>(element::i32, shape);
266         auto b = make_shared<op::Parameter>(element::i32, shape);
267         auto iconst1 = construct_constant_node(1);
268         auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
269         run_passes(pass_manager, graph, {a, b});
270         ASSERT_EQ(graph->input_value(0).get_node_shared_ptr(), a);
271         ASSERT_EQ(graph->input_value(0), a->output(0)); // graph's input points to a's output
272         ASSERT_TRUE(a->get_output_target_inputs(0).count(
273             graph->input(0))); // a's output feeds into graph's input
274     }
275
276     {
277         auto a = make_shared<op::Parameter>(element::i32, shape);
278         auto b = make_shared<op::Parameter>(element::i32, shape);
279         auto iconst0 = construct_constant_node(0);
280         auto iconst1 = construct_constant_node(1);
281         auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
282         run_passes(pass_manager, graph, {a, b});
283         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
284         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
285         ASSERT_TRUE(a->get_output_target_inputs(0).count(
286             graph->input(1))); // a's output feeds into graph's input
287     }
288
289     {
290         auto a = make_shared<op::Parameter>(element::i32, shape);
291         auto b = make_shared<op::Parameter>(element::i32, shape);
292         auto iconst1 = construct_constant_node(1);
293         auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
294         run_passes(pass_manager, graph, {a, b});
295         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
296         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
297         ASSERT_TRUE(a->get_output_target_inputs(0).count(
298             graph->input(1))); // a's output feeds into graph's input
299     }
300 }
301
302 TEST(pattern, matcher)
303 {
304     Shape shape{};
305     auto a = make_shared<op::Parameter>(element::i32, shape);
306     TestMatcher n;
307     ASSERT_TRUE(n.match(a, a));
308     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
309
310     auto abs = make_shared<op::Abs>(a);
311     auto any = std::make_shared<pattern::op::Skip>(a);
312     ASSERT_TRUE(n.match(any, abs));
313     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
314
315     auto false_pred = [](std::shared_ptr<Node> /* no */) { return false; };
316     auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
317     ASSERT_TRUE(n.match(any_false, a));
318     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
319
320     auto pattern = std::make_shared<pattern::op::Label>(a);
321     ASSERT_TRUE(n.match(pattern, a));
322     ASSERT_EQ(n.get_pattern_map()[pattern], a);
323     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
324
325     auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
326     ASSERT_FALSE(n.match(pattern_false, a));
327     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
328
329     auto b = make_shared<op::Parameter>(element::i32, shape);
330
331     auto is_bea = [](std::shared_ptr<Node> node) -> bool {
332         return op::is_binary_elementwise_arithmetic(node);
333     };
334     auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
335     auto add_ab = a + b;
336     ASSERT_TRUE(n.match(bea, add_ab));
337     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
338     ASSERT_TRUE(n.match(bea, b + a));
339
340     auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
341     ASSERT_FALSE(n.match(bea_false, a + b));
342
343     auto add_abs_b = abs + b;
344     auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
345     ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
346
347     auto add_b_abs = b + abs;
348     ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
349
350     auto bea_any_of_label =
351         std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
352     ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
353     ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
354
355     auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
356     auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
357     ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
358     ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
359
360     auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
361     auto ab = a + b;
362     ASSERT_TRUE(n.match(bea_label, ab));
363     ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
364
365     auto d = make_shared<op::Parameter>(element::i32, shape);
366     ASSERT_FALSE(n.match(d, b));
367
368     ASSERT_FALSE(n.match(abs + b, b + b));
369     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
370
371     auto add_absb = abs + b;
372     ASSERT_TRUE(n.match(any + b, add_absb));
373     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
374
375     ASSERT_TRUE(n.match(pattern + b, add_absb));
376     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
377     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
378
379     ASSERT_TRUE(n.match(b + pattern, add_absb));
380     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
381     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
382
383     auto c = make_shared<op::Parameter>(element::i32, shape);
384     auto mul_add_absb = c * (add_absb);
385     ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
386     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
387     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
388
389     ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
390     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
391     ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
392     auto mul_c_add_ab = c * add_ab;
393     ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b)));  // nested any
394     ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
395     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
396
397     auto iconst1_0 = construct_constant_node(1);
398     auto iconst1_1 = construct_constant_node(1);
399     ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
400     ASSERT_EQ(n.get_pattern_map()[pattern], a);
401     auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
402     auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
403     ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); // different iconst
404
405     // Subgraph labels
406     auto add = a + b;
407     auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
408     ASSERT_TRUE(n.match(label, add));
409     ASSERT_EQ(n.get_pattern_map()[label], add);
410     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
411
412     ASSERT_FALSE(n.match(label, a - b));
413
414     ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
415     ASSERT_EQ(n.get_pattern_map()[label], add);
416
417     // Correct argument order
418     ASSERT_FALSE(n.match(b - a, a - b));
419     auto aab = a * (a - b);
420     auto paab = pattern * (pattern - b);
421     ASSERT_TRUE(n.match(paab, aab));
422     auto aba = a * (b - a);
423     ASSERT_FALSE(n.match(paab, aba));
424     auto paba = pattern * (b - pattern);
425     ASSERT_FALSE(n.match(paba, aab));
426
427     // Correlations
428     auto label1 = std::make_shared<pattern::op::Label>(a);
429     auto tmp = label1 + b;
430     auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
431     auto sub_label1 = label1 - label2;
432     auto sub_add = a - add;
433     ASSERT_TRUE(n.match(sub_label1, sub_add));
434     ASSERT_EQ(n.get_pattern_map()[label1], a);
435     ASSERT_EQ(n.get_pattern_map()[label2], add);
436     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
437
438     ASSERT_FALSE(n.match(sub_label1, add - a));
439
440     auto add_label1 = label1 + label2;
441     ASSERT_TRUE(n.match(add_label1, add + a));
442     ASSERT_EQ(n.get_pattern_map()[label1], a);
443     ASSERT_EQ(n.get_pattern_map()[label2], add);
444
445     // Or
446     ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
447     ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
448
449     // Branch
450     {
451         auto branch = std::make_shared<pattern::op::Branch>();
452         auto star = std::make_shared<pattern::op::Or>(
453             OutputVector{branch, std::make_shared<pattern::op::True>()});
454         auto pattern = star + star;
455         branch->set_destination(pattern);
456         ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
457         ASSERT_EQ(n.get_matched_nodes().size(), 4);
458     }
459
460     // strict mode
461     {
462         TestMatcher sm(Output<Node>{}, "TestMatcher", true);
463         // exact shape and type
464         auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
465         auto label_dynamic_shape =
466             make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
467         auto param = make_shared<op::Parameter>(element::f32, Shape{});
468         ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
469         // wrong type
470         auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
471         ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
472         // dynamic dimension
473         auto label_dynamic_dimension =
474             make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
475         auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
476         ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
477         // dynamic type
478         auto label_dynamic_type =
479             make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
480         ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
481     }
482 }
483
484 TEST(pattern, mean)
485 {
486     // construct mean
487     TestMatcher n;
488
489     auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
490     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
491     auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
492     auto mean = std::make_shared<op::Divide>(sum_input1, N);
493
494     auto mean_graph = construct_mean_graph();
495     ASSERT_TRUE(n.match(mean_graph, mean));
496     ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
497 }
498
499 TEST(pattern, variance)
500 {
501     // construct variance
502     TestMatcher n;
503     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
504     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
505     auto input_sq = std::make_shared<op::Multiply>(input, input);
506     auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
507     auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
508     auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
509     auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
510     auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
511     auto variance = std::make_shared<op::Divide>(xmu, N);
512
513     auto var_graph = construct_variance_graph();
514     ASSERT_TRUE(n.match(var_graph, variance));
515     ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
516 }
517
518 TEST(pattern, previous_matches)
519 {
520     using ngraph::pattern::Matcher;
521     Shape shape{};
522     Matcher::PatternMap previous_matches;
523     auto a = make_shared<op::Parameter>(element::i32, shape);
524     auto b = make_shared<op::Parameter>(element::i32, shape);
525     auto pattern = std::make_shared<pattern::op::Label>(b);
526     auto abs = make_shared<op::Abs>(a);
527     auto add = abs + b;
528     {
529         Matcher n(pattern + b);
530         ASSERT_TRUE(n.match(add, previous_matches));
531         ASSERT_EQ(n.get_pattern_map()[pattern], abs);
532     }
533
534     {
535         Matcher n(pattern + b);
536         previous_matches.insert(std::make_pair(pattern, a));
537         ASSERT_FALSE(n.match(add, previous_matches));
538     }
539 }
540
541 TEST(pattern, test_sort)
542 {
543     using ngraph::pattern::Matcher;
544     Shape shape{};
545
546     auto a = make_shared<op::Parameter>(element::i32, shape);
547     auto b = make_shared<op::Parameter>(element::i32, shape);
548     auto abs1 = make_shared<op::Abs>(a);
549     auto abs2 = make_shared<op::Abs>(b);
550     auto add = abs1 + abs2;
551
552     auto pa = make_shared<op::Parameter>(element::i32, shape);
553     auto pb = make_shared<op::Parameter>(element::i32, shape);
554     auto pabs1 = make_shared<op::Abs>(pa);
555     auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
556     auto pabs2 = make_shared<op::Abs>(b);
557     auto padd = pabs1_label + pabs2;
558
559     {
560         Matcher n1(padd);
561         ASSERT_TRUE(n1.match(add));
562         auto r1 = n1.get_pattern_map()[pabs1_label];
563         ASSERT_TRUE(n1.match(add));
564         ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
565     }
566 }
567
568 TEST(pattern, recurrent_pattern)
569 {
570     using ngraph::pattern::RecurrentMatcher;
571     Shape shape{};
572     ngraph::pattern::Matcher::PatternMap previous_matches;
573     auto a = make_shared<op::Parameter>(element::i32, shape);
574     auto b = make_shared<op::Parameter>(element::i32, shape);
575     auto rpattern = std::make_shared<pattern::op::Label>(b);
576     auto iconst0 = construct_constant_node(0);
577     auto abs = make_shared<op::Abs>(a);
578     auto add1 = iconst0 + b;
579     auto add2 = iconst0 + add1;
580     auto add3 = iconst0 + add2;
581     auto padd = iconst0 + rpattern;
582     std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
583     RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
584     ASSERT_TRUE(rm.match(add3));
585     ASSERT_EQ(rm.get_number_of_bound_labels(), 3);
586     auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
587     ASSERT_EQ(recurrent_matches.at(0), add2);
588     ASSERT_EQ(recurrent_matches.at(1), add1);
589     ASSERT_EQ(recurrent_matches.at(2), b);
590
591     // Multiple labels in a reccuring pattern
592     auto iconst1 = construct_constant_node(1);
593     auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
594     auto add2_2 = iconst1 + add1;
595     auto add3_2 = iconst0 + add2_2;
596     auto padd2 = iconst_label + rpattern;
597     RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
598     ASSERT_TRUE(rm2.match(add3_2));
599     ASSERT_EQ(rm2.get_number_of_bound_labels(), 4);
600     recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
601     ASSERT_EQ(recurrent_matches.at(0), add2_2);
602     ASSERT_EQ(recurrent_matches.at(1), add1);
603     ASSERT_EQ(recurrent_matches.at(2), b);
604     auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
605     ASSERT_EQ(iconst_matches.at(0), iconst0);
606     ASSERT_EQ(iconst_matches.at(1), iconst1);
607     ASSERT_EQ(iconst_matches.at(2), iconst0);
608
609     // Non-matching correlated labels
610     std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
611     correlated_matches.insert(iconst_label);
612     RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
613     ASSERT_TRUE(rm3.match(add3_2));
614     ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
615     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
616     ASSERT_EQ(iconst_matches.size(), 1);
617     ASSERT_EQ(iconst_matches.at(0), iconst0);
618
619     // Matching correlated labels and
620     // testing if RecurrentMatcher can be reused for different nodes
621     ASSERT_TRUE(rm3.match(add3));
622     ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
623     recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
624     ASSERT_EQ(recurrent_matches.at(0), add2);
625     ASSERT_EQ(recurrent_matches.at(1), add1);
626     ASSERT_EQ(recurrent_matches.at(2), b);
627     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
628     ASSERT_EQ(iconst_matches.at(0), iconst0);
629     ASSERT_EQ(iconst_matches.at(1), iconst0);
630     ASSERT_EQ(iconst_matches.at(2), iconst0);
631 }
632
633 class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
634 {
635 public:
636     void construct_recurrent_add()
637     {
638         Shape shape{};
639         auto iconst0 = construct_constant_node(0);
640         auto iconst_label =
641             std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
642         auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
643         auto padd = iconst_label + rpattern;
644
645         auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
646             NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
647                          << rm.get_match_root()->get_name();
648
649             auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
650
651             auto is_iconst_zero = [](std::shared_ptr<Node> n) {
652                 bool result = ngraph::is_zero(n);
653                 NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
654                 return ngraph::is_zero(n);
655             };
656
657             bool are_all_iconst_zeros =
658                 std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
659
660             if (!are_all_iconst_zeros)
661             {
662                 return false;
663             }
664
665             auto number_of_adds = rm.get_number_of_recurrent_matches();
666             // replace the topmost add with the seed (i.e. the first parameter to add)
667             // matches are added in reverse order (i.e. the first match is the topmost node)
668             auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
669             NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
670                          << arg->get_name();
671             ngraph::replace_node(rm.get_match_root(), arg);
672             return true;
673         };
674
675         std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
676         auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
677         NGRAPH_SUPPRESS_DEPRECATED_START
678         this->add_matcher(rm, callback);
679         NGRAPH_SUPPRESS_DEPRECATED_END
680     }
681
682     TestRecurrentGraphRewrite()
683         : RecurrentGraphRewrite()
684     {
685         construct_recurrent_add();
686     }
687 };
688
689 TEST(pattern, recurrent_graph_rewrite)
690 {
691     Shape shape{};
692     pass::Manager pass_manager;
693     pass_manager.register_pass<TestRecurrentGraphRewrite>();
694
695     {
696         auto a = make_shared<op::Parameter>(element::i32, shape);
697         auto iconst0 = construct_constant_node(0);
698         auto add_a1 = a + iconst0;
699         auto add_a2 = add_a1 + iconst0;
700         auto add_a3 = add_a2 + iconst0;
701         auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
702
703         auto b = make_shared<op::Parameter>(element::i32, shape);
704         auto add_b1 = b + iconst0;
705         auto add_b2 = add_b1 + iconst0;
706         auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
707
708         auto graph = abs_add_a3 * abs_add_b2;
709
710         auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
711         pass_manager.run_passes(f);
712
713         auto left_abs = graph->input_value(0).get_node_shared_ptr();
714         auto add_a = left_abs->input_value(0).get_node_shared_ptr();
715         ASSERT_EQ(add_a, a);
716
717         auto right_abs = graph->input_value(1).get_node_shared_ptr();
718         auto add_b = right_abs->input_value(0).get_node_shared_ptr();
719         ASSERT_EQ(add_b, b);
720     }
721 }
722
723 TEST(pattern, label_on_skip)
724 {
725     Shape shape{2, 2};
726     auto a = make_shared<op::Parameter>(element::i32, shape);
727     auto b = make_shared<op::Parameter>(element::i32, Shape{});
728     auto iconst = ngraph::make_zero(element::i32, Shape{});
729     auto label = std::make_shared<pattern::op::Label>(iconst);
730     auto const_label =
731         std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
732
733     auto bcst_pred = [](std::shared_ptr<Node> n) {
734         return as_type_ptr<op::Broadcast>(n) != nullptr;
735     };
736
737     auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
738     auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
739     auto matcher = std::make_shared<pattern::Matcher>(
740         std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");
741
742     auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
743     auto mul = a * const_broadcast;
744     auto mul_scalar = b * iconst;
745     ASSERT_TRUE(matcher->match(mul));
746     ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
747     ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
748     ASSERT_EQ(matcher->get_pattern_map()[label], a);
749     ASSERT_TRUE(matcher->match(mul_scalar));
750     ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
751     ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
752     ASSERT_EQ(matcher->get_pattern_map()[label], b);
753 }
754
755 TEST(pattern, is_contained_match)
756 {
757     Shape shape{};
758     auto a = make_shared<op::Parameter>(element::i32, shape);
759     auto absn = make_shared<op::Abs>(a);
760     TestMatcher n;
761
762     auto label_a = std::make_shared<pattern::op::Label>(a);
763     auto label_abs = make_shared<op::Abs>(a);
764     ASSERT_TRUE(n.match(label_abs, absn));
765     auto result_absn = make_shared<op::Result>(absn);
766     ASSERT_TRUE(n.is_contained_match());
767
768     auto absn2 = make_shared<op::Abs>(absn);
769     auto result_absn2 = make_shared<op::Result>(absn2);
770     auto label_abs2 = make_shared<op::Abs>(label_abs);
771     ASSERT_TRUE(n.match(label_abs2, absn2));
772     ASSERT_FALSE(n.is_contained_match());
773 }
774
775 TEST(pattern, wrap_type)
776 {
777     auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
778     auto b = make_shared<op::Abs>(a);
779     auto c = make_shared<op::Relu>(a);
780     auto mul1 = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
781     auto mul2 = make_shared<op::v1::Multiply>(op::Constant::create(element::f32, Shape{}, {1}), a);
782
783     {
784         auto m = pattern::wrap_type<op::Abs>();
785         auto matcher = std::make_shared<pattern::Matcher>(m, "AbsMatcher");
786         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
787         ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
788         ASSERT_EQ(matcher->get_matched_nodes()[0], b);
789         ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
790         ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
791     }
792     {
793         auto m1 = pattern::wrap_type<op::Parameter>();
794         auto m2 = pattern::wrap_type<op::Abs>({m1});
795         auto matcher = std::make_shared<pattern::Matcher>(m2, "ParamAbsMatcher");
796         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
797         ASSERT_EQ(matcher->get_matched_nodes().size(), 2);
798         ASSERT_EQ(matcher->get_pattern_map().count(m1), 1);
799         ASSERT_EQ(matcher->get_pattern_map().count(m2), 1);
800         ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
801     }
802     {
803         auto m1 = pattern::wrap_type<op::v1::Multiply>(
804             {pattern::any_input(), pattern::wrap_type<op::Constant>()});
805         auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
806         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
807         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
808     }
809     {
810         auto m1 = pattern::wrap_type<op::v1::Multiply>(
811             {pattern::wrap_type<op::Constant>(), pattern::any_input()});
812         auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
813         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
814         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
815     }
816 }