publish master branch snapshot, revision 9df5eb1f84e13a35720a918f88324561222ab114
[platform/upstream/dldt.git] / ngraph / test / pattern.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18 #include <cstdio>
19 #include <iostream>
20 #include <list>
21 #include <memory>
22
23 #include "gtest/gtest.h"
24 #include "ngraph/file_util.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/log.hpp"
27 #include "ngraph/ngraph.hpp"
28 #include "ngraph/op/add.hpp"
29 #include "ngraph/op/batch_norm.hpp"
30 #include "ngraph/op/constant.hpp"
31 #include "ngraph/op/divide.hpp"
32 #include "ngraph/op/multiply.hpp"
33 #include "ngraph/op/sqrt.hpp"
34 #include "ngraph/op/subtract.hpp"
35 #include "ngraph/op/sum.hpp"
36 #include "ngraph/op/sum.hpp"
37 #include "ngraph/pass/graph_rewrite.hpp"
38 #include "ngraph/pass/manager.hpp"
39 #include "ngraph/pattern/matcher.hpp"
40 #include "ngraph/pattern/op/branch.hpp"
41 #include "ngraph/pattern/op/label.hpp"
42 #include "ngraph/pattern/op/or.hpp"
43 #include "ngraph/pattern/op/skip.hpp"
44 #include "ngraph/pattern/op/true.hpp"
45 #include "ngraph/serializer.hpp"
46 #include "util/matcher.hpp"
47 #include "util/test_tools.hpp"
48
49 using namespace ngraph;
50 using namespace std;
51
52 static std::shared_ptr<Node> construct_constant_node(int n)
53 {
54     return op::Constant::create(element::i32, Shape{}, {n});
55 }
56
57 static std::shared_ptr<pattern::op::Label> construct_variance_graph()
58 {
59     // construct varaiance
60     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
61     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
62     auto input_sq = std::make_shared<op::Multiply>(input, input);
63     auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
64     auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
65     auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
66     auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
67     auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
68     auto variance = std::make_shared<op::Divide>(xmu, N);
69     auto variance_label =
70         std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
71
72     return variance_label;
73 }
74
75 static std::shared_ptr<pattern::op::Label> construct_mean_graph()
76 {
77     // construct mean;
78     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
79     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
80     auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
81     auto mean = std::make_shared<op::Divide>(sum_input1, N);
82     auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
83     return mean_label;
84 }
85
86 class TestGraphRewrite : public ngraph::pass::GraphRewrite
87 {
88 public:
89     void construct_multiply_by_one()
90     {
91         // pattern #1 : a * 1 = a
92         auto iconst1 = construct_constant_node(1);
93         auto pattern = std::make_shared<pattern::op::Label>(iconst1);
94
95         auto callback = [pattern](pattern::Matcher& m) {
96             NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
97                          << m.get_match_root()->get_name();
98             NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
99
100             auto pattern_map = m.get_pattern_map();
101
102             size_t const_node_index =
103                 m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
104             auto const_node =
105                 as_type_ptr<op::Constant>(m.get_match_root()->get_arguments().at(const_node_index));
106             auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
107             NGRAPH_DEBUG << "second_node = " << second_node->get_name()
108                          << " , pattern = " << pattern_map[pattern]->get_name();
109
110             if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
111                 pattern_map[pattern]->get_shape() != const_node->get_shape())
112             {
113                 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
114                 return false;
115             }
116
117             auto const_values = const_node->get_vector<int32_t>();
118             bool all_ones =
119                 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
120
121             if (!all_ones)
122             {
123                 NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
124                 return false;
125             }
126
127             ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
128             return true;
129         };
130
131         auto m = make_shared<TestMatcher>(pattern * iconst1);
132         this->add_matcher(m, callback);
133     }
134
135     void construct_add_zero()
136     {
137         // pattern #2 : a + 0 = a
138         auto iconst0 = construct_constant_node(0);
139         auto pattern = std::make_shared<pattern::op::Label>(iconst0);
140
141         auto callback = [pattern](pattern::Matcher& m) {
142             NGRAPH_DEBUG << "In a callback for construct_add_zero against "
143                          << m.get_match_root()->get_name();
144             NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
145
146             auto pattern_map = m.get_pattern_map();
147
148             size_t const_node_index =
149                 m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
150             auto const_node =
151                 as_type_ptr<op::Constant>(m.get_match_root()->get_arguments().at(const_node_index));
152             auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
153             NGRAPH_DEBUG << "second_node = " << second_node->get_name()
154                          << " , pattern = " << pattern_map[pattern]->get_name();
155
156             if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
157                 pattern_map[pattern]->get_shape() != const_node->get_shape())
158             {
159                 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
160                 return false;
161             }
162
163             auto const_values = const_node->get_vector<int>();
164             bool all_zeros =
165                 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
166
167             if (!all_zeros)
168             {
169                 NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
170                 return false;
171             }
172
173             ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
174             return true;
175         };
176
177         auto add = pattern + iconst0;
178         auto m = make_shared<TestMatcher>(add);
179         this->add_matcher(m, callback);
180     }
181
182     TestGraphRewrite()
183         : GraphRewrite()
184     {
185         construct_multiply_by_one();
186         construct_add_zero();
187     }
188 };
189
190 static void run_passes(pass::Manager& pass_manager,
191                        shared_ptr<Node> graph,
192                        std::vector<shared_ptr<op::Parameter>> parms)
193 {
194     auto func = make_shared<Function>(graph, ParameterVector{parms});
195     pass_manager.run_passes(func);
196 }
197
198 TEST(pattern, graph_rewrite)
199 {
200     Shape shape{};
201     pass::Manager pass_manager;
202     pass_manager.register_pass<TestGraphRewrite>();
203
204     {
205         auto a = make_shared<op::Parameter>(element::i32, shape);
206         auto b = make_shared<op::Parameter>(element::i32, shape);
207         auto c = make_shared<op::Parameter>(element::i32, shape);
208         auto iconst0 = construct_constant_node(0);
209         auto graph_a = a + iconst0;
210         auto graph_b = b + iconst0;
211
212         auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
213                                             ParameterVector{a, b, c});
214         pass_manager.run_passes(f);
215
216         ASSERT_TRUE(graph_a->get_output_target_inputs(0).empty());
217         ASSERT_TRUE(graph_b->get_output_target_inputs(0).empty());
218
219         auto expected = ngraph::NodeVector{a, b, a, c, b};
220         ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
221     }
222
223     {
224         auto a = make_shared<op::Parameter>(element::i32, shape);
225         auto b = make_shared<op::Parameter>(element::i32, shape);
226         auto iconst0 = construct_constant_node(0);
227         auto sum = (a + iconst0);
228         auto graph = b + sum;
229         run_passes(pass_manager, graph, {a, b});
230         ASSERT_EQ(graph->get_arguments().at(1), a);
231         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
232         ASSERT_TRUE(sum->output(0)
233                         .get_target_inputs()
234                         .empty()); // graph's input is removed from sum's target inptus
235         ASSERT_TRUE(a->get_output_target_inputs(0).count(
236             graph->input(1))); // a's output feeds into graph's input
237     }
238
239     {
240         auto a = make_shared<op::Parameter>(element::i32, shape);
241         auto b = make_shared<op::Parameter>(element::i32, shape);
242         auto iconst1 = construct_constant_node(1);
243         auto mul = (a * iconst1);
244         auto graph = b + mul;
245         run_passes(pass_manager, graph, {a, b});
246         ASSERT_EQ(graph->get_arguments().at(1), a);
247         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
248         ASSERT_TRUE(mul->output(0)
249                         .get_target_inputs()
250                         .empty()); // graph's input is removed from sum's target inputs
251         ASSERT_TRUE(a->get_output_target_inputs(0).count(
252             graph->input(1))); // a's output feeds into graph's input
253     }
254
255     {
256         auto a = make_shared<op::Parameter>(element::i32, shape);
257         auto b = make_shared<op::Parameter>(element::i32, shape);
258         auto iconst1 = construct_constant_node(1);
259         auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
260         run_passes(pass_manager, graph, {a, b});
261         ASSERT_EQ(graph->get_arguments().at(0), a);
262         ASSERT_EQ(graph->input_value(0), a->output(0)); // graph's input points to a's output
263         ASSERT_TRUE(a->get_output_target_inputs(0).count(
264             graph->input(0))); // a's output feeds into graph's input
265     }
266
267     {
268         auto a = make_shared<op::Parameter>(element::i32, shape);
269         auto b = make_shared<op::Parameter>(element::i32, shape);
270         auto iconst0 = construct_constant_node(0);
271         auto iconst1 = construct_constant_node(1);
272         auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
273         run_passes(pass_manager, graph, {a, b});
274         ASSERT_EQ(graph->get_arguments().at(1), a);
275         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
276         ASSERT_TRUE(a->get_output_target_inputs(0).count(
277             graph->input(1))); // a's output feeds into graph's input
278     }
279
280     {
281         auto a = make_shared<op::Parameter>(element::i32, shape);
282         auto b = make_shared<op::Parameter>(element::i32, shape);
283         auto iconst1 = construct_constant_node(1);
284         auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
285         run_passes(pass_manager, graph, {a, b});
286         ASSERT_EQ(graph->get_arguments().at(1), a);
287         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
288         ASSERT_TRUE(a->get_output_target_inputs(0).count(
289             graph->input(1))); // a's output feeds into graph's input
290     }
291 }
292
293 TEST(pattern, matcher)
294 {
295     Shape shape{};
296     auto a = make_shared<op::Parameter>(element::i32, shape);
297     TestMatcher n;
298     ASSERT_TRUE(n.match(a, a));
299     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
300
301     auto abs = make_shared<op::Abs>(a);
302     auto any = std::make_shared<pattern::op::Skip>(a);
303     ASSERT_TRUE(n.match(any, abs));
304     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
305
306     auto false_pred = [](std::shared_ptr<Node> /* no */) { return false; };
307     auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
308     ASSERT_TRUE(n.match(any_false, a));
309     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
310
311     auto pattern = std::make_shared<pattern::op::Label>(a);
312     ASSERT_TRUE(n.match(pattern, a));
313     ASSERT_EQ(n.get_pattern_map()[pattern], a);
314     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
315
316     auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
317     ASSERT_FALSE(n.match(pattern_false, a));
318     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
319
320     auto b = make_shared<op::Parameter>(element::i32, shape);
321
322     auto is_bea = [](std::shared_ptr<Node> node) -> bool {
323         return node->is_binary_elementwise_arithmetic();
324     };
325     auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
326     auto add_ab = a + b;
327     ASSERT_TRUE(n.match(bea, add_ab));
328     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
329     ASSERT_TRUE(n.match(bea, b + a));
330
331     auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
332     ASSERT_FALSE(n.match(bea_false, a + b));
333
334     auto add_abs_b = abs + b;
335     auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
336     ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
337
338     auto add_b_abs = b + abs;
339     ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
340
341     auto bea_any_of_label =
342         std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
343     ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
344     ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
345
346     auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
347     auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
348     ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
349     ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
350
351     auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
352     auto ab = a + b;
353     ASSERT_TRUE(n.match(bea_label, ab));
354     ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
355
356     auto d = make_shared<op::Parameter>(element::i32, shape);
357     ASSERT_FALSE(n.match(d, b));
358
359     ASSERT_FALSE(n.match(abs + b, b + b));
360     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
361
362     auto add_absb = abs + b;
363     ASSERT_TRUE(n.match(any + b, add_absb));
364     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
365
366     ASSERT_TRUE(n.match(pattern + b, add_absb));
367     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
368     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
369
370     ASSERT_TRUE(n.match(b + pattern, add_absb));
371     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
372     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
373
374     auto c = make_shared<op::Parameter>(element::i32, shape);
375     auto mul_add_absb = c * (add_absb);
376     ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
377     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
378     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
379
380     ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
381     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
382     ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
383     auto mul_c_add_ab = c * add_ab;
384     ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b)));  // nested any
385     ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
386     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
387
388     auto iconst1_0 = construct_constant_node(1);
389     auto iconst1_1 = construct_constant_node(1);
390     ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
391     ASSERT_EQ(n.get_pattern_map()[pattern], a);
392     auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
393     auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
394     ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); // different iconst
395
396     // Subgraph labels
397     auto add = a + b;
398     auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
399     ASSERT_TRUE(n.match(label, add));
400     ASSERT_EQ(n.get_pattern_map()[label], add);
401     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
402
403     ASSERT_FALSE(n.match(label, a - b));
404
405     ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
406     ASSERT_EQ(n.get_pattern_map()[label], add);
407
408     // Correct argument order
409     ASSERT_FALSE(n.match(b - a, a - b));
410     auto aab = a * (a - b);
411     auto paab = pattern * (pattern - b);
412     ASSERT_TRUE(n.match(paab, aab));
413     auto aba = a * (b - a);
414     ASSERT_FALSE(n.match(paab, aba));
415     auto paba = pattern * (b - pattern);
416     ASSERT_FALSE(n.match(paba, aab));
417
418     // Correlations
419     auto label1 = std::make_shared<pattern::op::Label>(a);
420     auto tmp = label1 + b;
421     auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
422     auto sub_label1 = label1 - label2;
423     auto sub_add = a - add;
424     ASSERT_TRUE(n.match(sub_label1, sub_add));
425     ASSERT_EQ(n.get_pattern_map()[label1], a);
426     ASSERT_EQ(n.get_pattern_map()[label2], add);
427     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
428
429     ASSERT_FALSE(n.match(sub_label1, add - a));
430
431     auto add_label1 = label1 + label2;
432     ASSERT_TRUE(n.match(add_label1, add + a));
433     ASSERT_EQ(n.get_pattern_map()[label1], a);
434     ASSERT_EQ(n.get_pattern_map()[label2], add);
435
436     // Or
437     ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
438     ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
439
440     // Branch
441     {
442         auto branch = std::make_shared<pattern::op::Branch>();
443         auto star = std::make_shared<pattern::op::Or>(
444             OutputVector{branch, std::make_shared<pattern::op::True>()});
445         auto pattern = star + star;
446         branch->set_destination(pattern);
447         ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
448         ASSERT_EQ(n.get_matched_nodes().size(), 4);
449     }
450
451     // strict mode
452     {
453         TestMatcher sm(Output<Node>{}, "TestMatcher", true);
454         // exact shape and type
455         auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
456         auto label_dynamic_shape =
457             make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
458         auto param = make_shared<op::Parameter>(element::f32, Shape{});
459         ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
460         // wrong type
461         auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
462         ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
463         // dynamic dimension
464         auto label_dynamic_dimension =
465             make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
466         auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
467         ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
468         // dynamic type
469         auto label_dynamic_type =
470             make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
471         ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
472     }
473 }
474
475 TEST(pattern, mean)
476 {
477     // construct mean
478     TestMatcher n;
479
480     auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
481     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
482     auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
483     auto mean = std::make_shared<op::Divide>(sum_input1, N);
484
485     auto mean_graph = construct_mean_graph();
486     ASSERT_TRUE(n.match(mean_graph, mean));
487     ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
488 }
489
490 TEST(pattern, variance)
491 {
492     // construct variance
493     TestMatcher n;
494     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
495     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
496     auto input_sq = std::make_shared<op::Multiply>(input, input);
497     auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
498     auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
499     auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
500     auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
501     auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
502     auto variance = std::make_shared<op::Divide>(xmu, N);
503
504     auto var_graph = construct_variance_graph();
505     ASSERT_TRUE(n.match(var_graph, variance));
506     ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
507 }
508
509 TEST(pattern, previous_matches)
510 {
511     using ngraph::pattern::Matcher;
512     Shape shape{};
513     Matcher::PatternMap previous_matches;
514     auto a = make_shared<op::Parameter>(element::i32, shape);
515     auto b = make_shared<op::Parameter>(element::i32, shape);
516     auto pattern = std::make_shared<pattern::op::Label>(b);
517     auto abs = make_shared<op::Abs>(a);
518     auto add = abs + b;
519     {
520         Matcher n(pattern + b);
521         ASSERT_TRUE(n.match(add, previous_matches));
522         ASSERT_EQ(n.get_pattern_map()[pattern], abs);
523     }
524
525     {
526         Matcher n(pattern + b);
527         previous_matches.insert(std::make_pair(pattern, a));
528         ASSERT_FALSE(n.match(add, previous_matches));
529     }
530 }
531
532 TEST(pattern, test_sort)
533 {
534     using ngraph::pattern::Matcher;
535     Shape shape{};
536
537     auto a = make_shared<op::Parameter>(element::i32, shape);
538     auto b = make_shared<op::Parameter>(element::i32, shape);
539     auto abs1 = make_shared<op::Abs>(a);
540     auto abs2 = make_shared<op::Abs>(b);
541     auto add = abs1 + abs2;
542
543     auto pa = make_shared<op::Parameter>(element::i32, shape);
544     auto pb = make_shared<op::Parameter>(element::i32, shape);
545     auto pabs1 = make_shared<op::Abs>(pa);
546     auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
547     auto pabs2 = make_shared<op::Abs>(b);
548     auto padd = pabs1_label + pabs2;
549
550     {
551         Matcher n1(padd);
552         ASSERT_TRUE(n1.match(add));
553         auto r1 = n1.get_pattern_map()[pabs1_label];
554         ASSERT_TRUE(n1.match(add));
555         ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
556     }
557 }
558
559 TEST(pattern, recurrent_pattern)
560 {
561     using ngraph::pattern::RecurrentMatcher;
562     Shape shape{};
563     ngraph::pattern::Matcher::PatternMap previous_matches;
564     auto a = make_shared<op::Parameter>(element::i32, shape);
565     auto b = make_shared<op::Parameter>(element::i32, shape);
566     auto rpattern = std::make_shared<pattern::op::Label>(b);
567     auto iconst0 = construct_constant_node(0);
568     auto abs = make_shared<op::Abs>(a);
569     auto add1 = iconst0 + b;
570     auto add2 = iconst0 + add1;
571     auto add3 = iconst0 + add2;
572     auto padd = iconst0 + rpattern;
573     std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
574     RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
575     ASSERT_TRUE(rm.match(add3));
576     ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
577     auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
578     ASSERT_EQ(recurrent_matches.at(0), add2);
579     ASSERT_EQ(recurrent_matches.at(1), add1);
580     ASSERT_EQ(recurrent_matches.at(2), b);
581
582     // Multiple labels in a reccuring pattern
583     auto iconst1 = construct_constant_node(1);
584     auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
585     auto add2_2 = iconst1 + add1;
586     auto add3_2 = iconst0 + add2_2;
587     auto padd2 = iconst_label + rpattern;
588     RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
589     ASSERT_TRUE(rm2.match(add3_2));
590     ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
591     recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
592     ASSERT_EQ(recurrent_matches.at(0), add2_2);
593     ASSERT_EQ(recurrent_matches.at(1), add1);
594     ASSERT_EQ(recurrent_matches.at(2), b);
595     auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
596     ASSERT_EQ(iconst_matches.at(0), iconst0);
597     ASSERT_EQ(iconst_matches.at(1), iconst1);
598     ASSERT_EQ(iconst_matches.at(2), iconst0);
599
600     // Non-matching correlated labels
601     std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
602     correlated_matches.insert(iconst_label);
603     RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
604     ASSERT_TRUE(rm3.match(add3_2));
605     ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
606     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
607     ASSERT_EQ(iconst_matches.size(), 1);
608     ASSERT_EQ(iconst_matches.at(0), iconst0);
609
610     // Matching correlated labels and
611     // testing if RecurrentMatcher can be reused for different nodes
612     ASSERT_TRUE(rm3.match(add3));
613     ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
614     recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
615     ASSERT_EQ(recurrent_matches.at(0), add2);
616     ASSERT_EQ(recurrent_matches.at(1), add1);
617     ASSERT_EQ(recurrent_matches.at(2), b);
618     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
619     ASSERT_EQ(iconst_matches.at(0), iconst0);
620     ASSERT_EQ(iconst_matches.at(1), iconst0);
621     ASSERT_EQ(iconst_matches.at(2), iconst0);
622 }
623
624 class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
625 {
626 public:
627     void construct_recurrent_add()
628     {
629         Shape shape{};
630         auto iconst0 = construct_constant_node(0);
631         auto iconst_label =
632             std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
633         auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
634         auto padd = iconst_label + rpattern;
635
636         auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
637             NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
638                          << rm.get_match_root()->get_name();
639
640             auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
641
642             auto is_iconst_zero = [](std::shared_ptr<Node> n) {
643                 bool result = ngraph::is_zero(n);
644                 NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
645                 return ngraph::is_zero(n);
646             };
647
648             bool are_all_iconst_zeros =
649                 std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
650
651             if (!are_all_iconst_zeros)
652             {
653                 return false;
654             }
655
656             auto number_of_adds = rm.get_number_of_recurrent_matches();
657             // replace the topmost add with the seed (i.e. the first parameter to add)
658             // matches are added in reverse order (i.e. the first match is the topmost node)
659             auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
660             NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
661                          << arg->get_name();
662             ngraph::replace_node(rm.get_match_root(), arg);
663             return true;
664         };
665
666         std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
667         auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
668         this->add_matcher(rm, callback);
669     }
670
671     TestRecurrentGraphRewrite()
672         : RecurrentGraphRewrite()
673     {
674         construct_recurrent_add();
675     }
676 };
677
678 TEST(pattern, recurrent_graph_rewrite)
679 {
680     Shape shape{};
681     pass::Manager pass_manager;
682     pass_manager.register_pass<TestRecurrentGraphRewrite>();
683
684     {
685         auto a = make_shared<op::Parameter>(element::i32, shape);
686         auto iconst0 = construct_constant_node(0);
687         auto add_a1 = a + iconst0;
688         auto add_a2 = add_a1 + iconst0;
689         auto add_a3 = add_a2 + iconst0;
690         auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
691
692         auto b = make_shared<op::Parameter>(element::i32, shape);
693         auto add_b1 = b + iconst0;
694         auto add_b2 = add_b1 + iconst0;
695         auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
696
697         auto graph = abs_add_a3 * abs_add_b2;
698
699         auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
700         pass_manager.run_passes(f);
701
702         auto left_abs = graph->get_argument(0);
703         auto add_a = left_abs->get_argument(0);
704         ASSERT_EQ(add_a, a);
705
706         auto right_abs = graph->get_argument(1);
707         auto add_b = right_abs->get_argument(0);
708         ASSERT_EQ(add_b, b);
709     }
710 }
711
712 TEST(pattern, label_on_skip)
713 {
714     Shape shape{2, 2};
715     auto a = make_shared<op::Parameter>(element::i32, shape);
716     auto b = make_shared<op::Parameter>(element::i32, Shape{});
717     auto iconst = ngraph::make_zero(element::i32, Shape{});
718     auto label = std::make_shared<pattern::op::Label>(iconst);
719     auto const_label =
720         std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
721
722     auto bcst_pred = [](std::shared_ptr<Node> n) {
723         return as_type_ptr<op::Broadcast>(n) != nullptr;
724     };
725
726     auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
727     auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
728     auto matcher = std::make_shared<pattern::Matcher>(
729         std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");
730
731     auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
732     auto mul = a * const_broadcast;
733     auto mul_scalar = b * iconst;
734     ASSERT_TRUE(matcher->match(mul));
735     ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
736     ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
737     ASSERT_EQ(matcher->get_pattern_map()[label], a);
738     ASSERT_TRUE(matcher->match(mul_scalar));
739     ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
740     ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
741     ASSERT_EQ(matcher->get_pattern_map()[label], b);
742 }
743
744 TEST(pattern, is_contained_match)
745 {
746     Shape shape{};
747     auto a = make_shared<op::Parameter>(element::i32, shape);
748     auto absn = make_shared<op::Abs>(a);
749     TestMatcher n;
750
751     auto label_a = std::make_shared<pattern::op::Label>(a);
752     auto label_abs = make_shared<op::Abs>(a);
753     ASSERT_TRUE(n.match(label_abs, absn));
754     auto result_absn = make_shared<op::Result>(absn);
755     ASSERT_TRUE(n.is_contained_match());
756
757     auto absn2 = make_shared<op::Abs>(absn);
758     auto result_absn2 = make_shared<op::Result>(absn2);
759     auto label_abs2 = make_shared<op::Abs>(label_abs);
760     ASSERT_TRUE(n.match(label_abs2, absn2));
761     ASSERT_FALSE(n.is_contained_match());
762 }