nGraph Transformations refactoring (#931)
[platform/upstream/dldt.git] / ngraph / test / pattern.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18 #include <cstdio>
19 #include <iostream>
20 #include <list>
21 #include <memory>
22 #include <ngraph/pattern/op/wrap_type.hpp>
23
24 #include "gtest/gtest.h"
25 #include "ngraph/file_util.hpp"
26 #include "ngraph/graph_util.hpp"
27 #include "ngraph/log.hpp"
28 #include "ngraph/ngraph.hpp"
29 #include "ngraph/op/add.hpp"
30 #include "ngraph/op/batch_norm.hpp"
31 #include "ngraph/op/constant.hpp"
32 #include "ngraph/op/divide.hpp"
33 #include "ngraph/op/multiply.hpp"
34 #include "ngraph/op/sqrt.hpp"
35 #include "ngraph/op/subtract.hpp"
36 #include "ngraph/op/sum.hpp"
37 #include "ngraph/op/sum.hpp"
38 #include "ngraph/op/util/op_types.hpp"
39 #include "ngraph/pass/graph_rewrite.hpp"
40 #include "ngraph/pass/manager.hpp"
41 #include "ngraph/pattern/matcher.hpp"
42 #include "ngraph/pattern/op/branch.hpp"
43 #include "ngraph/pattern/op/label.hpp"
44 #include "ngraph/pattern/op/or.hpp"
45 #include "ngraph/pattern/op/skip.hpp"
46 #include "ngraph/pattern/op/true.hpp"
47 #include "ngraph/serializer.hpp"
48 #include "util/matcher.hpp"
49 #include "util/test_tools.hpp"
50
51 using namespace ngraph;
52 using namespace std;
53
54 static std::shared_ptr<Node> construct_constant_node(int n)
55 {
56     return op::Constant::create(element::i32, Shape{}, {n});
57 }
58
59 static std::shared_ptr<pattern::op::Label> construct_variance_graph()
60 {
61     // construct varaiance
62     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
63     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
64     auto input_sq = std::make_shared<op::Multiply>(input, input);
65     auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
66     auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
67     auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
68     auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
69     auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
70     auto variance = std::make_shared<op::Divide>(xmu, N);
71     auto variance_label =
72         std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
73
74     return variance_label;
75 }
76
77 static std::shared_ptr<pattern::op::Label> construct_mean_graph()
78 {
79     // construct mean;
80     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
81     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
82     auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
83     auto mean = std::make_shared<op::Divide>(sum_input1, N);
84     auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
85     return mean_label;
86 }
87
88 class TestGraphRewrite : public ngraph::pass::GraphRewrite
89 {
90 public:
91     void construct_multiply_by_one()
92     {
93         // pattern #1 : a * 1 = a
94         auto iconst1 = construct_constant_node(1);
95         auto pattern = std::make_shared<pattern::op::Label>(iconst1);
96
97         auto callback = [pattern](pattern::Matcher& m) {
98             NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
99                          << m.get_match_root()->get_name();
100             NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
101
102             auto pattern_map = m.get_pattern_map();
103
104             size_t const_node_index =
105                 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
106             auto const_node = as_type_ptr<op::Constant>(
107                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
108             auto second_node =
109                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
110             NGRAPH_DEBUG << "second_node = " << second_node->get_name()
111                          << " , pattern = " << pattern_map[pattern]->get_name();
112
113             if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
114                 pattern_map[pattern]->get_shape() != const_node->get_shape())
115             {
116                 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
117                 return false;
118             }
119
120             auto const_values = const_node->get_vector<int32_t>();
121             bool all_ones =
122                 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
123
124             if (!all_ones)
125             {
126                 NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
127                 return false;
128             }
129
130             ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
131             return true;
132         };
133
134         auto m = make_shared<TestMatcher>(pattern * iconst1);
135         this->add_matcher(m, callback);
136     }
137
138     void construct_add_zero()
139     {
140         // pattern #2 : a + 0 = a
141         auto iconst0 = construct_constant_node(0);
142         auto pattern = std::make_shared<pattern::op::Label>(iconst0);
143
144         auto callback = [pattern](pattern::Matcher& m) {
145             NGRAPH_DEBUG << "In a callback for construct_add_zero against "
146                          << m.get_match_root()->get_name();
147             NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
148
149             auto pattern_map = m.get_pattern_map();
150
151             size_t const_node_index =
152                 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
153             auto const_node = as_type_ptr<op::Constant>(
154                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
155             auto second_node =
156                 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
157             NGRAPH_DEBUG << "second_node = " << second_node->get_name()
158                          << " , pattern = " << pattern_map[pattern]->get_name();
159
160             if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
161                 pattern_map[pattern]->get_shape() != const_node->get_shape())
162             {
163                 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
164                 return false;
165             }
166
167             auto const_values = const_node->get_vector<int>();
168             bool all_zeros =
169                 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
170
171             if (!all_zeros)
172             {
173                 NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
174                 return false;
175             }
176
177             ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
178             return true;
179         };
180
181         auto add = pattern + iconst0;
182         auto m = make_shared<TestMatcher>(add);
183         this->add_matcher(m, callback);
184     }
185
186     TestGraphRewrite()
187         : GraphRewrite()
188     {
189         construct_multiply_by_one();
190         construct_add_zero();
191     }
192 };
193
194 static void run_passes(pass::Manager& pass_manager,
195                        shared_ptr<Node> graph,
196                        std::vector<shared_ptr<op::Parameter>> parms)
197 {
198     auto func = make_shared<Function>(graph, ParameterVector{parms});
199     pass_manager.run_passes(func);
200 }
201
202 TEST(pattern, graph_rewrite)
203 {
204     Shape shape{};
205     pass::Manager pass_manager;
206     pass_manager.register_pass<TestGraphRewrite>();
207
208     {
209         auto a = make_shared<op::Parameter>(element::i32, shape);
210         auto b = make_shared<op::Parameter>(element::i32, shape);
211         auto c = make_shared<op::Parameter>(element::i32, shape);
212         auto iconst0 = construct_constant_node(0);
213         auto graph_a = a + iconst0;
214         auto graph_b = b + iconst0;
215
216         auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
217                                             ParameterVector{a, b, c});
218         pass_manager.run_passes(f);
219
220         ASSERT_TRUE(graph_a->get_output_target_inputs(0).empty());
221         ASSERT_TRUE(graph_b->get_output_target_inputs(0).empty());
222
223         auto expected = ngraph::NodeVector{a, b, a, c, b};
224         ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
225     }
226
227     {
228         auto a = make_shared<op::Parameter>(element::i32, shape);
229         auto b = make_shared<op::Parameter>(element::i32, shape);
230         auto iconst0 = construct_constant_node(0);
231         auto sum = (a + iconst0);
232         auto graph = b + sum;
233         run_passes(pass_manager, graph, {a, b});
234         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
235         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
236         ASSERT_TRUE(sum->output(0)
237                         .get_target_inputs()
238                         .empty()); // graph's input is removed from sum's target inptus
239         ASSERT_TRUE(a->get_output_target_inputs(0).count(
240             graph->input(1))); // a's output feeds into graph's input
241     }
242
243     {
244         auto a = make_shared<op::Parameter>(element::i32, shape);
245         auto b = make_shared<op::Parameter>(element::i32, shape);
246         auto iconst1 = construct_constant_node(1);
247         auto mul = (a * iconst1);
248         auto graph = b + mul;
249         run_passes(pass_manager, graph, {a, b});
250         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
251         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
252         ASSERT_TRUE(mul->output(0)
253                         .get_target_inputs()
254                         .empty()); // graph's input is removed from sum's target inputs
255         ASSERT_TRUE(a->get_output_target_inputs(0).count(
256             graph->input(1))); // a's output feeds into graph's input
257     }
258
259     {
260         auto a = make_shared<op::Parameter>(element::i32, shape);
261         auto b = make_shared<op::Parameter>(element::i32, shape);
262         auto iconst1 = construct_constant_node(1);
263         auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
264         run_passes(pass_manager, graph, {a, b});
265         ASSERT_EQ(graph->input_value(0).get_node_shared_ptr(), a);
266         ASSERT_EQ(graph->input_value(0), a->output(0)); // graph's input points to a's output
267         ASSERT_TRUE(a->get_output_target_inputs(0).count(
268             graph->input(0))); // a's output feeds into graph's input
269     }
270
271     {
272         auto a = make_shared<op::Parameter>(element::i32, shape);
273         auto b = make_shared<op::Parameter>(element::i32, shape);
274         auto iconst0 = construct_constant_node(0);
275         auto iconst1 = construct_constant_node(1);
276         auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
277         run_passes(pass_manager, graph, {a, b});
278         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
279         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
280         ASSERT_TRUE(a->get_output_target_inputs(0).count(
281             graph->input(1))); // a's output feeds into graph's input
282     }
283
284     {
285         auto a = make_shared<op::Parameter>(element::i32, shape);
286         auto b = make_shared<op::Parameter>(element::i32, shape);
287         auto iconst1 = construct_constant_node(1);
288         auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
289         run_passes(pass_manager, graph, {a, b});
290         ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
291         ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
292         ASSERT_TRUE(a->get_output_target_inputs(0).count(
293             graph->input(1))); // a's output feeds into graph's input
294     }
295 }
296
297 TEST(pattern, matcher)
298 {
299     Shape shape{};
300     auto a = make_shared<op::Parameter>(element::i32, shape);
301     TestMatcher n;
302     ASSERT_TRUE(n.match(a, a));
303     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
304
305     auto abs = make_shared<op::Abs>(a);
306     auto any = std::make_shared<pattern::op::Skip>(a);
307     ASSERT_TRUE(n.match(any, abs));
308     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
309
310     auto false_pred = [](std::shared_ptr<Node> /* no */) { return false; };
311     auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
312     ASSERT_TRUE(n.match(any_false, a));
313     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
314
315     auto pattern = std::make_shared<pattern::op::Label>(a);
316     ASSERT_TRUE(n.match(pattern, a));
317     ASSERT_EQ(n.get_pattern_map()[pattern], a);
318     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
319
320     auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
321     ASSERT_FALSE(n.match(pattern_false, a));
322     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
323
324     auto b = make_shared<op::Parameter>(element::i32, shape);
325
326     auto is_bea = [](std::shared_ptr<Node> node) -> bool {
327         return op::is_binary_elementwise_arithmetic(node);
328     };
329     auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
330     auto add_ab = a + b;
331     ASSERT_TRUE(n.match(bea, add_ab));
332     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
333     ASSERT_TRUE(n.match(bea, b + a));
334
335     auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
336     ASSERT_FALSE(n.match(bea_false, a + b));
337
338     auto add_abs_b = abs + b;
339     auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
340     ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
341
342     auto add_b_abs = b + abs;
343     ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
344
345     auto bea_any_of_label =
346         std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
347     ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
348     ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
349
350     auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
351     auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
352     ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
353     ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
354
355     auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
356     auto ab = a + b;
357     ASSERT_TRUE(n.match(bea_label, ab));
358     ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
359
360     auto d = make_shared<op::Parameter>(element::i32, shape);
361     ASSERT_FALSE(n.match(d, b));
362
363     ASSERT_FALSE(n.match(abs + b, b + b));
364     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
365
366     auto add_absb = abs + b;
367     ASSERT_TRUE(n.match(any + b, add_absb));
368     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
369
370     ASSERT_TRUE(n.match(pattern + b, add_absb));
371     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
372     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
373
374     ASSERT_TRUE(n.match(b + pattern, add_absb));
375     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
376     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
377
378     auto c = make_shared<op::Parameter>(element::i32, shape);
379     auto mul_add_absb = c * (add_absb);
380     ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
381     ASSERT_EQ(n.get_pattern_map()[pattern], abs);
382     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
383
384     ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
385     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
386     ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
387     auto mul_c_add_ab = c * add_ab;
388     ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b)));  // nested any
389     ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
390     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
391
392     auto iconst1_0 = construct_constant_node(1);
393     auto iconst1_1 = construct_constant_node(1);
394     ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
395     ASSERT_EQ(n.get_pattern_map()[pattern], a);
396     auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
397     auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
398     ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); // different iconst
399
400     // Subgraph labels
401     auto add = a + b;
402     auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
403     ASSERT_TRUE(n.match(label, add));
404     ASSERT_EQ(n.get_pattern_map()[label], add);
405     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
406
407     ASSERT_FALSE(n.match(label, a - b));
408
409     ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
410     ASSERT_EQ(n.get_pattern_map()[label], add);
411
412     // Correct argument order
413     ASSERT_FALSE(n.match(b - a, a - b));
414     auto aab = a * (a - b);
415     auto paab = pattern * (pattern - b);
416     ASSERT_TRUE(n.match(paab, aab));
417     auto aba = a * (b - a);
418     ASSERT_FALSE(n.match(paab, aba));
419     auto paba = pattern * (b - pattern);
420     ASSERT_FALSE(n.match(paba, aab));
421
422     // Correlations
423     auto label1 = std::make_shared<pattern::op::Label>(a);
424     auto tmp = label1 + b;
425     auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
426     auto sub_label1 = label1 - label2;
427     auto sub_add = a - add;
428     ASSERT_TRUE(n.match(sub_label1, sub_add));
429     ASSERT_EQ(n.get_pattern_map()[label1], a);
430     ASSERT_EQ(n.get_pattern_map()[label2], add);
431     ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
432
433     ASSERT_FALSE(n.match(sub_label1, add - a));
434
435     auto add_label1 = label1 + label2;
436     ASSERT_TRUE(n.match(add_label1, add + a));
437     ASSERT_EQ(n.get_pattern_map()[label1], a);
438     ASSERT_EQ(n.get_pattern_map()[label2], add);
439
440     // Or
441     ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
442     ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
443
444     // Branch
445     {
446         auto branch = std::make_shared<pattern::op::Branch>();
447         auto star = std::make_shared<pattern::op::Or>(
448             OutputVector{branch, std::make_shared<pattern::op::True>()});
449         auto pattern = star + star;
450         branch->set_destination(pattern);
451         ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
452         ASSERT_EQ(n.get_matched_nodes().size(), 4);
453     }
454
455     // strict mode
456     {
457         TestMatcher sm(Output<Node>{}, "TestMatcher", true);
458         // exact shape and type
459         auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
460         auto label_dynamic_shape =
461             make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
462         auto param = make_shared<op::Parameter>(element::f32, Shape{});
463         ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
464         // wrong type
465         auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
466         ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
467         // dynamic dimension
468         auto label_dynamic_dimension =
469             make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
470         auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
471         ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
472         // dynamic type
473         auto label_dynamic_type =
474             make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
475         ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
476     }
477 }
478
479 TEST(pattern, mean)
480 {
481     // construct mean
482     TestMatcher n;
483
484     auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
485     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
486     auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
487     auto mean = std::make_shared<op::Divide>(sum_input1, N);
488
489     auto mean_graph = construct_mean_graph();
490     ASSERT_TRUE(n.match(mean_graph, mean));
491     ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
492 }
493
494 TEST(pattern, variance)
495 {
496     // construct variance
497     TestMatcher n;
498     auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
499     auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
500     auto input_sq = std::make_shared<op::Multiply>(input, input);
501     auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
502     auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
503     auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
504     auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
505     auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
506     auto variance = std::make_shared<op::Divide>(xmu, N);
507
508     auto var_graph = construct_variance_graph();
509     ASSERT_TRUE(n.match(var_graph, variance));
510     ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
511 }
512
513 TEST(pattern, previous_matches)
514 {
515     using ngraph::pattern::Matcher;
516     Shape shape{};
517     Matcher::PatternMap previous_matches;
518     auto a = make_shared<op::Parameter>(element::i32, shape);
519     auto b = make_shared<op::Parameter>(element::i32, shape);
520     auto pattern = std::make_shared<pattern::op::Label>(b);
521     auto abs = make_shared<op::Abs>(a);
522     auto add = abs + b;
523     {
524         Matcher n(pattern + b);
525         ASSERT_TRUE(n.match(add, previous_matches));
526         ASSERT_EQ(n.get_pattern_map()[pattern], abs);
527     }
528
529     {
530         Matcher n(pattern + b);
531         previous_matches.insert(std::make_pair(pattern, a));
532         ASSERT_FALSE(n.match(add, previous_matches));
533     }
534 }
535
536 TEST(pattern, test_sort)
537 {
538     using ngraph::pattern::Matcher;
539     Shape shape{};
540
541     auto a = make_shared<op::Parameter>(element::i32, shape);
542     auto b = make_shared<op::Parameter>(element::i32, shape);
543     auto abs1 = make_shared<op::Abs>(a);
544     auto abs2 = make_shared<op::Abs>(b);
545     auto add = abs1 + abs2;
546
547     auto pa = make_shared<op::Parameter>(element::i32, shape);
548     auto pb = make_shared<op::Parameter>(element::i32, shape);
549     auto pabs1 = make_shared<op::Abs>(pa);
550     auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
551     auto pabs2 = make_shared<op::Abs>(b);
552     auto padd = pabs1_label + pabs2;
553
554     {
555         Matcher n1(padd);
556         ASSERT_TRUE(n1.match(add));
557         auto r1 = n1.get_pattern_map()[pabs1_label];
558         ASSERT_TRUE(n1.match(add));
559         ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
560     }
561 }
562
563 TEST(pattern, recurrent_pattern)
564 {
565     using ngraph::pattern::RecurrentMatcher;
566     Shape shape{};
567     ngraph::pattern::Matcher::PatternMap previous_matches;
568     auto a = make_shared<op::Parameter>(element::i32, shape);
569     auto b = make_shared<op::Parameter>(element::i32, shape);
570     auto rpattern = std::make_shared<pattern::op::Label>(b);
571     auto iconst0 = construct_constant_node(0);
572     auto abs = make_shared<op::Abs>(a);
573     auto add1 = iconst0 + b;
574     auto add2 = iconst0 + add1;
575     auto add3 = iconst0 + add2;
576     auto padd = iconst0 + rpattern;
577     std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
578     RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
579     ASSERT_TRUE(rm.match(add3));
580     ASSERT_EQ(rm.get_number_of_bound_labels(), 3);
581     auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
582     ASSERT_EQ(recurrent_matches.at(0), add2);
583     ASSERT_EQ(recurrent_matches.at(1), add1);
584     ASSERT_EQ(recurrent_matches.at(2), b);
585
586     // Multiple labels in a reccuring pattern
587     auto iconst1 = construct_constant_node(1);
588     auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
589     auto add2_2 = iconst1 + add1;
590     auto add3_2 = iconst0 + add2_2;
591     auto padd2 = iconst_label + rpattern;
592     RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
593     ASSERT_TRUE(rm2.match(add3_2));
594     ASSERT_EQ(rm2.get_number_of_bound_labels(), 4);
595     recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
596     ASSERT_EQ(recurrent_matches.at(0), add2_2);
597     ASSERT_EQ(recurrent_matches.at(1), add1);
598     ASSERT_EQ(recurrent_matches.at(2), b);
599     auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
600     ASSERT_EQ(iconst_matches.at(0), iconst0);
601     ASSERT_EQ(iconst_matches.at(1), iconst1);
602     ASSERT_EQ(iconst_matches.at(2), iconst0);
603
604     // Non-matching correlated labels
605     std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
606     correlated_matches.insert(iconst_label);
607     RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
608     ASSERT_TRUE(rm3.match(add3_2));
609     ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
610     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
611     ASSERT_EQ(iconst_matches.size(), 1);
612     ASSERT_EQ(iconst_matches.at(0), iconst0);
613
614     // Matching correlated labels and
615     // testing if RecurrentMatcher can be reused for different nodes
616     ASSERT_TRUE(rm3.match(add3));
617     ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
618     recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
619     ASSERT_EQ(recurrent_matches.at(0), add2);
620     ASSERT_EQ(recurrent_matches.at(1), add1);
621     ASSERT_EQ(recurrent_matches.at(2), b);
622     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
623     ASSERT_EQ(iconst_matches.at(0), iconst0);
624     ASSERT_EQ(iconst_matches.at(1), iconst0);
625     ASSERT_EQ(iconst_matches.at(2), iconst0);
626 }
627
628 class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
629 {
630 public:
631     void construct_recurrent_add()
632     {
633         Shape shape{};
634         auto iconst0 = construct_constant_node(0);
635         auto iconst_label =
636             std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
637         auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
638         auto padd = iconst_label + rpattern;
639
640         auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
641             NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
642                          << rm.get_match_root()->get_name();
643
644             auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
645
646             auto is_iconst_zero = [](std::shared_ptr<Node> n) {
647                 bool result = ngraph::is_zero(n);
648                 NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
649                 return ngraph::is_zero(n);
650             };
651
652             bool are_all_iconst_zeros =
653                 std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
654
655             if (!are_all_iconst_zeros)
656             {
657                 return false;
658             }
659
660             auto number_of_adds = rm.get_number_of_recurrent_matches();
661             // replace the topmost add with the seed (i.e. the first parameter to add)
662             // matches are added in reverse order (i.e. the first match is the topmost node)
663             auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
664             NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
665                          << arg->get_name();
666             ngraph::replace_node(rm.get_match_root(), arg);
667             return true;
668         };
669
670         std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
671         auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
672         this->add_matcher(rm, callback);
673     }
674
675     TestRecurrentGraphRewrite()
676         : RecurrentGraphRewrite()
677     {
678         construct_recurrent_add();
679     }
680 };
681
682 TEST(pattern, recurrent_graph_rewrite)
683 {
684     Shape shape{};
685     pass::Manager pass_manager;
686     pass_manager.register_pass<TestRecurrentGraphRewrite>();
687
688     {
689         auto a = make_shared<op::Parameter>(element::i32, shape);
690         auto iconst0 = construct_constant_node(0);
691         auto add_a1 = a + iconst0;
692         auto add_a2 = add_a1 + iconst0;
693         auto add_a3 = add_a2 + iconst0;
694         auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
695
696         auto b = make_shared<op::Parameter>(element::i32, shape);
697         auto add_b1 = b + iconst0;
698         auto add_b2 = add_b1 + iconst0;
699         auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
700
701         auto graph = abs_add_a3 * abs_add_b2;
702
703         auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
704         pass_manager.run_passes(f);
705
706         auto left_abs = graph->input_value(0).get_node_shared_ptr();
707         auto add_a = left_abs->input_value(0).get_node_shared_ptr();
708         ASSERT_EQ(add_a, a);
709
710         auto right_abs = graph->input_value(1).get_node_shared_ptr();
711         auto add_b = right_abs->input_value(0).get_node_shared_ptr();
712         ASSERT_EQ(add_b, b);
713     }
714 }
715
716 TEST(pattern, label_on_skip)
717 {
718     Shape shape{2, 2};
719     auto a = make_shared<op::Parameter>(element::i32, shape);
720     auto b = make_shared<op::Parameter>(element::i32, Shape{});
721     auto iconst = ngraph::make_zero(element::i32, Shape{});
722     auto label = std::make_shared<pattern::op::Label>(iconst);
723     auto const_label =
724         std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
725
726     auto bcst_pred = [](std::shared_ptr<Node> n) {
727         return as_type_ptr<op::Broadcast>(n) != nullptr;
728     };
729
730     auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
731     auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
732     auto matcher = std::make_shared<pattern::Matcher>(
733         std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");
734
735     auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
736     auto mul = a * const_broadcast;
737     auto mul_scalar = b * iconst;
738     ASSERT_TRUE(matcher->match(mul));
739     ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
740     ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
741     ASSERT_EQ(matcher->get_pattern_map()[label], a);
742     ASSERT_TRUE(matcher->match(mul_scalar));
743     ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
744     ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
745     ASSERT_EQ(matcher->get_pattern_map()[label], b);
746 }
747
748 TEST(pattern, is_contained_match)
749 {
750     Shape shape{};
751     auto a = make_shared<op::Parameter>(element::i32, shape);
752     auto absn = make_shared<op::Abs>(a);
753     TestMatcher n;
754
755     auto label_a = std::make_shared<pattern::op::Label>(a);
756     auto label_abs = make_shared<op::Abs>(a);
757     ASSERT_TRUE(n.match(label_abs, absn));
758     auto result_absn = make_shared<op::Result>(absn);
759     ASSERT_TRUE(n.is_contained_match());
760
761     auto absn2 = make_shared<op::Abs>(absn);
762     auto result_absn2 = make_shared<op::Result>(absn2);
763     auto label_abs2 = make_shared<op::Abs>(label_abs);
764     ASSERT_TRUE(n.match(label_abs2, absn2));
765     ASSERT_FALSE(n.is_contained_match());
766 }
767
768 TEST(pattern, wrap_type)
769 {
770     auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
771     auto b = make_shared<op::Abs>(a);
772     auto c = make_shared<op::Relu>(a);
773     auto mul1 = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
774     auto mul2 = make_shared<op::v1::Multiply>(op::Constant::create(element::f32, Shape{}, {1}), a);
775
776     {
777         auto m = pattern::wrap_type<op::Abs>();
778         auto matcher = std::make_shared<pattern::Matcher>(m, "AbsMatcher");
779         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
780         ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
781         ASSERT_EQ(matcher->get_matched_nodes()[0], b);
782         ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
783         ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
784     }
785     {
786         auto m1 = pattern::wrap_type<op::Parameter>();
787         auto m2 = pattern::wrap_type<op::Abs>({m1});
788         auto matcher = std::make_shared<pattern::Matcher>(m2, "ParamAbsMatcher");
789         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
790         ASSERT_EQ(matcher->get_matched_nodes().size(), 2);
791         ASSERT_EQ(matcher->get_pattern_map().count(m1), 1);
792         ASSERT_EQ(matcher->get_pattern_map().count(m2), 1);
793         ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
794     }
795     {
796         auto m1 = pattern::wrap_type<op::v1::Multiply>(
797             {pattern::any_input(), pattern::wrap_type<op::Constant>()});
798         auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
799         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
800         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
801     }
802     {
803         auto m1 = pattern::wrap_type<op::v1::Multiply>(
804             {pattern::wrap_type<op::Constant>(), pattern::any_input()});
805         auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
806         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
807         ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
808     }
809 }