1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
23 #include "gtest/gtest.h"
24 #include "ngraph/file_util.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/log.hpp"
27 #include "ngraph/ngraph.hpp"
28 #include "ngraph/op/add.hpp"
29 #include "ngraph/op/batch_norm.hpp"
30 #include "ngraph/op/constant.hpp"
31 #include "ngraph/op/divide.hpp"
32 #include "ngraph/op/multiply.hpp"
33 #include "ngraph/op/sqrt.hpp"
34 #include "ngraph/op/subtract.hpp"
35 #include "ngraph/op/sum.hpp"
36 #include "ngraph/op/sum.hpp"
37 #include "ngraph/pass/graph_rewrite.hpp"
38 #include "ngraph/pass/manager.hpp"
39 #include "ngraph/pattern/matcher.hpp"
40 #include "ngraph/pattern/op/branch.hpp"
41 #include "ngraph/pattern/op/label.hpp"
42 #include "ngraph/pattern/op/or.hpp"
43 #include "ngraph/pattern/op/skip.hpp"
44 #include "ngraph/pattern/op/true.hpp"
45 #include "ngraph/serializer.hpp"
46 #include "util/matcher.hpp"
47 #include "util/test_tools.hpp"
49 using namespace ngraph;
// Builds a scalar (Shape{}) element::i32 Constant holding the single value `n`.
// The tests below use it for the literal 0 and 1 operands of their patterns.
52 static std::shared_ptr<Node> construct_constant_node(int n)
54 return op::Constant::create(element::i32, Shape{}, {n});
// Builds a pattern for the variance of an f32 {2, 3} input reduced over axis 0:
//   variance = (Sum(x * x) - Sum(x) * Sum(x) / N) / N,  with N = Constant{2, 2, 2}
// and returns a Label (no predicate) wrapping the final Divide, so a successful
// match binds the whole variance subgraph's root to the returned label.
// NOTE(review): the statement creating the Label below starts mid-expression and
// `variance_label` has no visible declaration. Confirm that
// `auto variance_label =` precedes it.
57 static std::shared_ptr<pattern::op::Label> construct_variance_graph()
59 // construct variance
60 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
61 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
62 auto input_sq = std::make_shared<op::Multiply>(input, input);
63 auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
64 auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
65 auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
66 auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
67 auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
68 auto variance = std::make_shared<op::Divide>(xmu, N);
70 std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
72 return variance_label;
// Builds a pattern for the mean of an f32 {2, 3} input over axis 0:
//   mean = Sum(x, {0}) / N,  with N = Constant{2, 2, 2}
// and wraps the Divide root in a predicate-less Label.
// NOTE(review): no return statement is visible for this function. It presumably
// returns `mean_label` (the declared return type is a Label); confirm.
75 static std::shared_ptr<pattern::op::Label> construct_mean_graph()
78 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
79 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
80 auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
81 auto mean = std::make_shared<op::Divide>(sum_input1, N);
82 auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
// GraphRewrite pass used by TEST(pattern, graph_rewrite). It registers two
// algebraic simplifications over scalar i32 constants from
// construct_constant_node:
//   #1  a * 1 -> a
//   #2  a + 0 -> a
// Each callback logs and skips the rewrite when the bound operand's element type
// or shape differs from the constant's, or when the constant's values are not all
// 1 (resp. 0). Otherwise it replaces the match root with the bound operand.
// NOTE(review): several statements here appear truncated. The declarations of
// `const_node` and of the std::all_of result, the early-return branches, and the
// constructor header are not visible.
86 class TestGraphRewrite : public ngraph::pass::GraphRewrite
89 void construct_multiply_by_one()
91 // pattern #1 : a * 1 = a
92 auto iconst1 = construct_constant_node(1);
93 auto pattern = std::make_shared<pattern::op::Label>(iconst1);
95 auto callback = [pattern](pattern::Matcher& m) {
96 NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
97 << m.get_match_root()->get_name();
98 NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
100 auto pattern_map = m.get_pattern_map();
// The comparison yields a bool converted to size_t. It is 1 when argument 0 is
// the node bound to `pattern` (so the constant is argument 1), and 0 otherwise.
// This lets the callback handle both `a * 1` and `1 * a`.
102 size_t const_node_index =
103 m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
105 as_type_ptr<op::Constant>(m.get_match_root()->get_arguments().at(const_node_index));
106 auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
107 NGRAPH_DEBUG << "second_node = " << second_node->get_name()
108 << " , pattern = " << pattern_map[pattern]->get_name();
// Only fold when the operand and the constant agree on element type and shape.
110 if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
111 pattern_map[pattern]->get_shape() != const_node->get_shape())
113 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
// Every element of the constant must equal 1 for `a * c` to reduce to `a`.
117 auto const_values = const_node->get_vector<int32_t>();
119 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
123 NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
// Rewire every consumer of the Multiply to the non-constant operand.
127 ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
131 auto m = make_shared<TestMatcher>(pattern * iconst1);
132 this->add_matcher(m, callback);
135 void construct_add_zero()
137 // pattern #2 : a + 0 = a
138 auto iconst0 = construct_constant_node(0);
139 auto pattern = std::make_shared<pattern::op::Label>(iconst0);
141 auto callback = [pattern](pattern::Matcher& m) {
142 NGRAPH_DEBUG << "In a callback for construct_add_zero against "
143 << m.get_match_root()->get_name();
144 NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
146 auto pattern_map = m.get_pattern_map();
// Index of the constant operand: 1 if argument 0 is the bound operand, else 0.
148 size_t const_node_index =
149 m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
151 as_type_ptr<op::Constant>(m.get_match_root()->get_arguments().at(const_node_index));
152 auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
153 NGRAPH_DEBUG << "second_node = " << second_node->get_name()
154 << " , pattern = " << pattern_map[pattern]->get_name();
// Only fold when the operand and the constant agree on element type and shape.
156 if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
157 pattern_map[pattern]->get_shape() != const_node->get_shape())
159 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
// NOTE(review): construct_multiply_by_one reads values with get_vector<int32_t>;
// this uses get_vector<int>. The two agree only where int is 32 bits.
163 auto const_values = const_node->get_vector<int>();
165 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
169 NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
// Rewire every consumer of the Add to the non-constant operand.
173 ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
177 auto add = pattern + iconst0;
178 auto m = make_shared<TestMatcher>(add);
179 this->add_matcher(m, callback);
// Constructor body: register both rewrites with this pass.
185 construct_multiply_by_one();
186 construct_add_zero();
// Wraps `graph` in a new Function whose parameters are `parms`, then runs every
// pass registered on `pass_manager` over it. The Function is not returned, so
// callers check the effect by inspecting the (rewritten) nodes they still hold.
190 static void run_passes(pass::Manager& pass_manager,
191 shared_ptr<Node> graph,
192 std::vector<shared_ptr<op::Parameter>> parms)
194 auto func = make_shared<Function>(graph, ParameterVector{parms});
195 pass_manager.run_passes(func);
// Exercises TestGraphRewrite (a + 0 -> a, a * 1 -> a) on several graph shapes:
// parallel adds, a nested add, a nested multiply, chains of "* 1", mixed
// "+ 0" / "* 1", and constants on the left-hand side.
// NOTE(review): `shape` is not declared in any visible line of this test.
198 TEST(pattern, graph_rewrite)
201 pass::Manager pass_manager;
202 pass_manager.register_pass<TestGraphRewrite>();
// Two independent "x + 0" outputs of one Function: both Adds must disappear.
205 auto a = make_shared<op::Parameter>(element::i32, shape);
206 auto b = make_shared<op::Parameter>(element::i32, shape);
207 auto c = make_shared<op::Parameter>(element::i32, shape);
208 auto iconst0 = construct_constant_node(0);
209 auto graph_a = a + iconst0;
210 auto graph_b = b + iconst0;
212 auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
213 ParameterVector{a, b, c});
214 pass_manager.run_passes(f);
216 ASSERT_TRUE(graph_a->get_output_target_inputs(0).empty());
217 ASSERT_TRUE(graph_b->get_output_target_inputs(0).empty());
// NOTE(review): `expected` is not referenced by any visible assertion.
219 auto expected = ngraph::NodeVector{a, b, a, c, b};
220 ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
// "b + (a + 0)": the inner Add is replaced by `a`.
// NOTE(review): the statement below is split across a missing line between
// `output(0)` and `.empty()`; confirm the intermediate call.
224 auto a = make_shared<op::Parameter>(element::i32, shape);
225 auto b = make_shared<op::Parameter>(element::i32, shape);
226 auto iconst0 = construct_constant_node(0);
227 auto sum = (a + iconst0);
228 auto graph = b + sum;
229 run_passes(pass_manager, graph, {a, b});
230 ASSERT_EQ(graph->get_arguments().at(1), a);
231 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
232 ASSERT_TRUE(sum->output(0)
234 .empty()); // graph's input is removed from sum's target inputs
235 ASSERT_TRUE(a->get_output_target_inputs(0).count(
236 graph->input(1))); // a's output feeds into graph's input
// "b + (a * 1)": the Multiply is replaced by `a`.
240 auto a = make_shared<op::Parameter>(element::i32, shape);
241 auto b = make_shared<op::Parameter>(element::i32, shape);
242 auto iconst1 = construct_constant_node(1);
243 auto mul = (a * iconst1);
244 auto graph = b + mul;
245 run_passes(pass_manager, graph, {a, b});
246 ASSERT_EQ(graph->get_arguments().at(1), a);
247 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
248 ASSERT_TRUE(mul->output(0)
250 .empty()); // graph's input is removed from mul's target inputs
251 ASSERT_TRUE(a->get_output_target_inputs(0).count(
252 graph->input(1))); // a's output feeds into graph's input
// A chain of four "* 1" collapses to `a`.
256 auto a = make_shared<op::Parameter>(element::i32, shape);
257 auto b = make_shared<op::Parameter>(element::i32, shape);
258 auto iconst1 = construct_constant_node(1);
259 auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
260 run_passes(pass_manager, graph, {a, b});
261 ASSERT_EQ(graph->get_arguments().at(0), a);
262 ASSERT_EQ(graph->input_value(0), a->output(0)); // graph's input points to a's output
263 ASSERT_TRUE(a->get_output_target_inputs(0).count(
264 graph->input(0))); // a's output feeds into graph's input
// Interleaved "+ 0" and "* 1" (including a constant on the left) collapse to `a`.
268 auto a = make_shared<op::Parameter>(element::i32, shape);
269 auto b = make_shared<op::Parameter>(element::i32, shape);
270 auto iconst0 = construct_constant_node(0);
271 auto iconst1 = construct_constant_node(1);
272 auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
273 run_passes(pass_manager, graph, {a, b});
274 ASSERT_EQ(graph->get_arguments().at(1), a);
275 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
276 ASSERT_TRUE(a->get_output_target_inputs(0).count(
277 graph->input(1))); // a's output feeds into graph's input
// Constants on the left of every Multiply: "1 * (1 * (1 * (1 * a)))" -> a.
281 auto a = make_shared<op::Parameter>(element::i32, shape);
282 auto b = make_shared<op::Parameter>(element::i32, shape);
283 auto iconst1 = construct_constant_node(1);
284 auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
285 run_passes(pass_manager, graph, {a, b});
286 ASSERT_EQ(graph->get_arguments().at(1), a);
287 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
288 ASSERT_TRUE(a->get_output_target_inputs(0).count(
289 graph->input(1))); // a's output feeds into graph's input
// Unit tests for pattern::Matcher. Covers plain nodes, Skip, Label (with and
// without predicates), Any/AnyOf, commutative permutations, constants, argument
// order for non-commutative ops, correlated labels, Or/Branch recursion, and
// dynamic-shape/type Labels.
// NOTE(review): `shape`, the matcher `n`, and the graph nodes `add_ab`, `ab`,
// `add` and `label` used below are declared on lines not visible here.
293 TEST(pattern, matcher)
// A node matches itself.
296 auto a = make_shared<op::Parameter>(element::i32, shape);
298 ASSERT_TRUE(n.match(a, a));
299 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
// Skip over `a` matches Abs(a); both nodes are recorded.
301 auto abs = make_shared<op::Abs>(a);
302 auto any = std::make_shared<pattern::op::Skip>(a);
303 ASSERT_TRUE(n.match(any, abs));
304 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
// With an always-false predicate, the Skip still matches `a` itself.
306 auto false_pred = [](std::shared_ptr<Node> /* no */) { return false; };
307 auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
308 ASSERT_TRUE(n.match(any_false, a));
309 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
// A Label binds the matched node in the pattern map; a false predicate rejects it.
311 auto pattern = std::make_shared<pattern::op::Label>(a);
312 ASSERT_TRUE(n.match(pattern, a));
313 ASSERT_EQ(n.get_pattern_map()[pattern], a);
314 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
316 auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
317 ASSERT_FALSE(n.match(pattern_false, a));
318 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
// Any / AnyOf keyed on "is binary elementwise arithmetic".
320 auto b = make_shared<op::Parameter>(element::i32, shape);
322 auto is_bea = [](std::shared_ptr<Node> node) -> bool {
323 return node->is_binary_elementwise_arithmetic();
325 auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
327 ASSERT_TRUE(n.match(bea, add_ab));
328 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
329 ASSERT_TRUE(n.match(bea, b + a));
331 auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
332 ASSERT_FALSE(n.match(bea_false, a + b));
// AnyOf with a single wrapped Abs matches it in either operand position.
334 auto add_abs_b = abs + b;
335 auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
336 ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
338 auto add_b_abs = b + abs;
339 ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
341 auto bea_any_of_label =
342 std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
343 ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
344 ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
346 auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
347 auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
348 ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
349 ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
351 auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
353 ASSERT_TRUE(n.match(bea_label, ab));
354 ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
// Distinct parameters do not match each other; a failed match leaves no nodes.
356 auto d = make_shared<op::Parameter>(element::i32, shape);
357 ASSERT_FALSE(n.match(d, b));
359 ASSERT_FALSE(n.match(abs + b, b + b));
360 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
// Nested patterns and operand permutations of commutative ops.
362 auto add_absb = abs + b;
363 ASSERT_TRUE(n.match(any + b, add_absb));
364 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
366 ASSERT_TRUE(n.match(pattern + b, add_absb));
367 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
368 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
370 ASSERT_TRUE(n.match(b + pattern, add_absb));
371 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
372 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
374 auto c = make_shared<op::Parameter>(element::i32, shape);
375 auto mul_add_absb = c * (add_absb);
376 ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
377 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
378 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
380 ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
381 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
382 ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
383 auto mul_c_add_ab = c * add_ab;
384 ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); // nested any
385 ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
386 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
// Constants: distinct Constant nodes with equal values match.
388 auto iconst1_0 = construct_constant_node(1);
389 auto iconst1_1 = construct_constant_node(1);
390 ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
391 ASSERT_EQ(n.get_pattern_map()[pattern], a);
392 auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
393 auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
394 ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); // f32 pattern constant vs i32 graph constant
// A Label wrapping a subgraph binds the subgraph's root.
398 auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
399 ASSERT_TRUE(n.match(label, add));
400 ASSERT_EQ(n.get_pattern_map()[label], add);
401 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
403 ASSERT_FALSE(n.match(label, a - b));
405 ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
406 ASSERT_EQ(n.get_pattern_map()[label], add);
408 // Correct argument order
409 ASSERT_FALSE(n.match(b - a, a - b));
410 auto aab = a * (a - b);
411 auto paab = pattern * (pattern - b);
412 ASSERT_TRUE(n.match(paab, aab));
413 auto aba = a * (b - a);
414 ASSERT_FALSE(n.match(paab, aba));
415 auto paba = pattern * (b - pattern);
416 ASSERT_FALSE(n.match(paba, aab));
// Correlated labels: label1 appears twice and must bind `a` at both sites.
419 auto label1 = std::make_shared<pattern::op::Label>(a);
420 auto tmp = label1 + b;
421 auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
422 auto sub_label1 = label1 - label2;
423 auto sub_add = a - add;
424 ASSERT_TRUE(n.match(sub_label1, sub_add));
425 ASSERT_EQ(n.get_pattern_map()[label1], a);
426 ASSERT_EQ(n.get_pattern_map()[label2], add);
427 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
429 ASSERT_FALSE(n.match(sub_label1, add - a));
431 auto add_label1 = label1 + label2;
432 ASSERT_TRUE(n.match(add_label1, add + a));
433 ASSERT_EQ(n.get_pattern_map()[label1], a);
434 ASSERT_EQ(n.get_pattern_map()[label2], add);
// Or: matches if any alternative matches.
437 ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
438 ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
// Branch: a recursive pattern (star + star, where star is the pattern itself or True).
442 auto branch = std::make_shared<pattern::op::Branch>();
443 auto star = std::make_shared<pattern::op::Or>(
444 OutputVector{branch, std::make_shared<pattern::op::True>()});
445 auto pattern = star + star;
446 branch->set_destination(pattern);
447 ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
448 ASSERT_EQ(n.get_matched_nodes().size(), 4);
// Labels with dynamic shape / dimension / element type. The trailing `true`
// passed to TestMatcher is not explained by any visible line; confirm its meaning.
// NOTE(review): `param` is not used by any visible assertion.
453 TestMatcher sm(Output<Node>{}, "TestMatcher", true);
454 // exact shape and type
455 auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
456 auto label_dynamic_shape =
457 make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
458 auto param = make_shared<op::Parameter>(element::f32, Shape{});
459 ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
461 auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
462 ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
464 auto label_dynamic_dimension =
465 make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
466 auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
467 ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
469 auto label_dynamic_type =
470 make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
471 ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
// Builds the mean of an f32 {2, 3} Parameter over axis 0 as a concrete graph,
// then checks that construct_mean_graph() matches it and binds the Divide root.
// NOTE(review): the enclosing TEST header and the declaration of the matcher `n`
// are not visible here.
480 auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
481 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
482 auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
483 auto mean = std::make_shared<op::Divide>(sum_input1, N);
485 auto mean_graph = construct_mean_graph();
486 ASSERT_TRUE(n.match(mean_graph, mean));
487 ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
// Builds the same variance computation as construct_variance_graph() and
// checks that the pattern matches it and binds the final Divide.
// Note that the "graph" input here is itself an f32 {2, 3} pattern Label, not a Parameter.
// NOTE(review): the declaration of the matcher `n` is not visible here.
490 TEST(pattern, variance)
492 // construct variance
494 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
495 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
496 auto input_sq = std::make_shared<op::Multiply>(input, input);
497 auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
498 auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
499 auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
500 auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
501 auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
502 auto variance = std::make_shared<op::Divide>(xmu, N);
504 auto var_graph = construct_variance_graph();
505 ASSERT_TRUE(n.match(var_graph, variance));
506 ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
// Checks Matcher::match(graph, previous_matches). With an empty map the Label
// binds `abs`. After pre-seeding the map with pattern -> a, the same match fails,
// because the label is already committed to a different node.
// NOTE(review): `shape` and the graph node `add` are declared on lines not visible here.
509 TEST(pattern, previous_matches)
511 using ngraph::pattern::Matcher;
513 Matcher::PatternMap previous_matches;
514 auto a = make_shared<op::Parameter>(element::i32, shape);
515 auto b = make_shared<op::Parameter>(element::i32, shape);
516 auto pattern = std::make_shared<pattern::op::Label>(b);
517 auto abs = make_shared<op::Abs>(a);
520 Matcher n(pattern + b);
521 ASSERT_TRUE(n.match(add, previous_matches));
522 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
526 Matcher n(pattern + b);
527 previous_matches.insert(std::make_pair(pattern, a));
528 ASSERT_FALSE(n.match(add, previous_matches));
// Matching the same graph twice with one Matcher must bind pabs1_label to the
// same node both times. This guards against order-dependent results.
// `pabs2` is built on the graph's own `b` (not on `pb`, which is unused).
// NOTE(review): `shape` and the matcher `n1` are declared on lines not visible here.
532 TEST(pattern, test_sort)
534 using ngraph::pattern::Matcher;
537 auto a = make_shared<op::Parameter>(element::i32, shape);
538 auto b = make_shared<op::Parameter>(element::i32, shape);
539 auto abs1 = make_shared<op::Abs>(a);
540 auto abs2 = make_shared<op::Abs>(b);
541 auto add = abs1 + abs2;
543 auto pa = make_shared<op::Parameter>(element::i32, shape);
544 auto pb = make_shared<op::Parameter>(element::i32, shape);
545 auto pabs1 = make_shared<op::Abs>(pa);
546 auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
547 auto pabs2 = make_shared<op::Abs>(b);
548 auto padd = pabs1_label + pabs2;
552 ASSERT_TRUE(n1.match(add));
553 auto r1 = n1.get_pattern_map()[pabs1_label];
554 ASSERT_TRUE(n1.match(add));
555 ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
// Tests pattern::RecurrentMatcher on chains of "const + x":
//  - a single recurring label (rpattern) binds each level's right operand,
//    topmost first;
//  - a second, uncorrelated label binds a (possibly different) constant per level;
//  - when that label is declared correlated, it must bind the same node at every
//    level. Otherwise matching stops after the first level.
// The matcher is reused across graphs to check that state is reset between matches.
// Cleanup: dropped the unused locals `previous_matches`, `a` and `abs`.
// NOTE(review): `shape` is declared on a line not visible here.
559 TEST(pattern, recurrent_pattern)
561 using ngraph::pattern::RecurrentMatcher;
565 auto b = make_shared<op::Parameter>(element::i32, shape);
566 auto rpattern = std::make_shared<pattern::op::Label>(b);
567 auto iconst0 = construct_constant_node(0);
569 auto add1 = iconst0 + b;
570 auto add2 = iconst0 + add1;
571 auto add3 = iconst0 + add2;
572 auto padd = iconst0 + rpattern;
573 std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
574 RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
575 ASSERT_TRUE(rm.match(add3));
576 ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
577 auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
578 ASSERT_EQ(recurrent_matches.at(0), add2);
579 ASSERT_EQ(recurrent_matches.at(1), add1);
580 ASSERT_EQ(recurrent_matches.at(2), b);
582 // Multiple labels in a recurring pattern
583 auto iconst1 = construct_constant_node(1);
584 auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
585 auto add2_2 = iconst1 + add1;
586 auto add3_2 = iconst0 + add2_2;
587 auto padd2 = iconst_label + rpattern;
588 RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
589 ASSERT_TRUE(rm2.match(add3_2));
590 ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
591 recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
592 ASSERT_EQ(recurrent_matches.at(0), add2_2);
593 ASSERT_EQ(recurrent_matches.at(1), add1);
594 ASSERT_EQ(recurrent_matches.at(2), b);
595 auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
596 ASSERT_EQ(iconst_matches.at(0), iconst0);
597 ASSERT_EQ(iconst_matches.at(1), iconst1);
598 ASSERT_EQ(iconst_matches.at(2), iconst0);
600 // Non-matching correlated labels
601 std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
602 correlated_matches.insert(iconst_label);
603 RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
604 ASSERT_TRUE(rm3.match(add3_2));
605 ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
606 iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
607 ASSERT_EQ(iconst_matches.size(), 1);
608 ASSERT_EQ(iconst_matches.at(0), iconst0);
610 // Matching correlated labels and
611 // testing if RecurrentMatcher can be reused for different nodes
612 ASSERT_TRUE(rm3.match(add3));
613 ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
614 recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
615 ASSERT_EQ(recurrent_matches.at(0), add2);
616 ASSERT_EQ(recurrent_matches.at(1), add1);
617 ASSERT_EQ(recurrent_matches.at(2), b);
618 iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
619 ASSERT_EQ(iconst_matches.at(0), iconst0);
620 ASSERT_EQ(iconst_matches.at(1), iconst0);
621 ASSERT_EQ(iconst_matches.at(2), iconst0);
// RecurrentGraphRewrite pass used by TEST(pattern, recurrent_graph_rewrite).
// It registers one RecurrentMatcher over the pattern `iconst_label + rpattern`.
// When every constant bound to `iconst_label` across the matched chain is zero,
// the callback replaces the chain's root with the innermost `rpattern` binding
// (the chain's seed).
// Fix: is_iconst_zero now evaluates ngraph::is_zero once and returns that cached
// result, so the logged verdict and the returned value cannot diverge.
// NOTE(review): some lines of this class are not visible here: the access
// specifier, the declaration of `shape`, the `auto iconst_label =` prefix, the
// early-return branch, the callback's final return, and the constructor braces.
624 class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
627 void construct_recurrent_add()
630 auto iconst0 = construct_constant_node(0);
632 std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
633 auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
634 auto padd = iconst_label + rpattern;
636 auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
637 NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
638 << rm.get_match_root()->get_name();
// One bound constant per recurrence of the pattern.
640 auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
642 auto is_iconst_zero = [](std::shared_ptr<Node> n) {
643 bool result = ngraph::is_zero(n);
644 NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
645 return result;
648 bool are_all_iconst_zeros =
649 std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
651 if (!are_all_iconst_zeros)
656 auto number_of_adds = rm.get_number_of_recurrent_matches();
657 // replace the topmost add with the seed (i.e. the first parameter to add)
658 // matches are added in reverse order (i.e. the first match is the topmost node)
659 auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
660 NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
662 ngraph::replace_node(rm.get_match_root(), arg);
// No correlated labels, so iconst_label may bind a different constant node at
// each recurrence. The callback checks that each bound constant is zero.
666 std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
667 auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
668 this->add_matcher(rm, callback);
// Constructor: registers the single recurrent rewrite.
671 TestRecurrentGraphRewrite()
672 : RecurrentGraphRewrite()
674 construct_recurrent_add();
// Runs TestRecurrentGraphRewrite on Abs((a + 0) + 0 + 0) * Abs((b + 0) + 0).
// The "+ 0" chains should be collapsed by the recurrent rewrite. The test then
// fetches each Abs's argument, presumably to check it is the bare parameter.
// NOTE(review): `shape` and the final assertions on `add_a` / `add_b` are on
// lines not visible here.
678 TEST(pattern, recurrent_graph_rewrite)
681 pass::Manager pass_manager;
682 pass_manager.register_pass<TestRecurrentGraphRewrite>();
685 auto a = make_shared<op::Parameter>(element::i32, shape);
686 auto iconst0 = construct_constant_node(0);
687 auto add_a1 = a + iconst0;
688 auto add_a2 = add_a1 + iconst0;
689 auto add_a3 = add_a2 + iconst0;
690 auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
692 auto b = make_shared<op::Parameter>(element::i32, shape);
693 auto add_b1 = b + iconst0;
694 auto add_b2 = add_b1 + iconst0;
695 auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
697 auto graph = abs_add_a3 * abs_add_b2;
699 auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
700 pass_manager.run_passes(f);
702 auto left_abs = graph->get_argument(0);
703 auto add_a = left_abs->get_argument(0);
706 auto right_abs = graph->get_argument(1);
707 auto add_b = right_abs->get_argument(0);
// A Label placed on a Skip whose predicate accepts only Broadcast nodes. The
// pattern `label * bcst_label` matches both `a * Broadcast(0)` (bcst_label binds
// the Broadcast) and `b * 0` (bcst_label binds the scalar zero directly). In both
// cases const_label, guarded by ngraph::is_zero, binds the zero constant.
// NOTE(review): `shape` and the `auto const_label =` prefix of the statement
// on the 720 line are on lines not visible here.
712 TEST(pattern, label_on_skip)
715 auto a = make_shared<op::Parameter>(element::i32, shape);
716 auto b = make_shared<op::Parameter>(element::i32, Shape{});
717 auto iconst = ngraph::make_zero(element::i32, Shape{});
718 auto label = std::make_shared<pattern::op::Label>(iconst);
720 std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
722 auto bcst_pred = [](std::shared_ptr<Node> n) {
723 return as_type_ptr<op::Broadcast>(n) != nullptr;
726 auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
727 auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
728 auto matcher = std::make_shared<pattern::Matcher>(
729 std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");
731 auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
732 auto mul = a * const_broadcast;
733 auto mul_scalar = b * iconst;
734 ASSERT_TRUE(matcher->match(mul));
735 ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
736 ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
737 ASSERT_EQ(matcher->get_pattern_map()[label], a);
738 ASSERT_TRUE(matcher->match(mul_scalar));
739 ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
740 ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
741 ASSERT_EQ(matcher->get_pattern_map()[label], b);
// Checks Matcher::is_contained_match(). Matching Abs(a) against absn reports a
// contained match. Matching Abs(Abs(a)) against absn2 does not, presumably
// because the inner node absn also feeds result_absn, which lies outside the
// matched subgraph.
// NOTE(review): `label_a` is never used, and `label_abs` is a plain Abs over `a`
// rather than a pattern Label. `shape` and the matcher `n` are declared on lines
// not visible here.
744 TEST(pattern, is_contained_match)
747 auto a = make_shared<op::Parameter>(element::i32, shape);
748 auto absn = make_shared<op::Abs>(a);
751 auto label_a = std::make_shared<pattern::op::Label>(a);
752 auto label_abs = make_shared<op::Abs>(a);
753 ASSERT_TRUE(n.match(label_abs, absn));
754 auto result_absn = make_shared<op::Result>(absn);
755 ASSERT_TRUE(n.is_contained_match());
757 auto absn2 = make_shared<op::Abs>(absn);
758 auto result_absn2 = make_shared<op::Result>(absn2);
759 auto label_abs2 = make_shared<op::Abs>(label_abs);
760 ASSERT_TRUE(n.match(label_abs2, absn2));
761 ASSERT_FALSE(n.is_contained_match());