1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
22 #include <ngraph/pattern/op/wrap_type.hpp>
24 #include "gtest/gtest.h"
25 #include "ngraph/file_util.hpp"
26 #include "ngraph/graph_util.hpp"
27 #include "ngraph/log.hpp"
28 #include "ngraph/ngraph.hpp"
29 #include "ngraph/op/add.hpp"
30 #include "ngraph/op/batch_norm.hpp"
31 #include "ngraph/op/constant.hpp"
32 #include "ngraph/op/divide.hpp"
33 #include "ngraph/op/multiply.hpp"
34 #include "ngraph/op/sqrt.hpp"
35 #include "ngraph/op/subtract.hpp"
36 #include "ngraph/op/sum.hpp"
37 #include "ngraph/op/sum.hpp"
38 #include "ngraph/op/util/op_types.hpp"
39 #include "ngraph/pass/graph_rewrite.hpp"
40 #include "ngraph/pass/manager.hpp"
41 #include "ngraph/pattern/matcher.hpp"
42 #include "ngraph/pattern/op/branch.hpp"
43 #include "ngraph/pattern/op/label.hpp"
44 #include "ngraph/pattern/op/or.hpp"
45 #include "ngraph/pattern/op/skip.hpp"
46 #include "ngraph/pattern/op/true.hpp"
47 #include "ngraph/serializer.hpp"
48 #include "util/matcher.hpp"
49 #include "util/test_tools.hpp"
51 using namespace ngraph;
54 static std::shared_ptr<Node> construct_constant_node(int n)
56 return op::Constant::create(element::i32, Shape{}, {n});
59 static std::shared_ptr<pattern::op::Label> construct_variance_graph()
61 // construct varaiance
62 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
63 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
64 auto input_sq = std::make_shared<op::Multiply>(input, input);
65 auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
66 auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
67 auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
68 auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
69 auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
70 auto variance = std::make_shared<op::Divide>(xmu, N);
72 std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
74 return variance_label;
77 static std::shared_ptr<pattern::op::Label> construct_mean_graph()
80 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
81 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
82 auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
83 auto mean = std::make_shared<op::Divide>(sum_input1, N);
84 auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
88 class TestGraphRewrite : public ngraph::pass::GraphRewrite
91 void construct_multiply_by_one()
93 // pattern #1 : a * 1 = a
94 auto iconst1 = construct_constant_node(1);
95 auto pattern = std::make_shared<pattern::op::Label>(iconst1);
97 auto callback = [pattern](pattern::Matcher& m) {
98 NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
99 << m.get_match_root()->get_name();
100 NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
102 auto pattern_map = m.get_pattern_map();
104 size_t const_node_index =
105 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
106 auto const_node = as_type_ptr<op::Constant>(
107 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
109 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
110 NGRAPH_DEBUG << "second_node = " << second_node->get_name()
111 << " , pattern = " << pattern_map[pattern]->get_name();
113 if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
114 pattern_map[pattern]->get_shape() != const_node->get_shape())
116 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
120 auto const_values = const_node->get_vector<int32_t>();
122 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
126 NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
130 ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
134 auto m = make_shared<TestMatcher>(pattern * iconst1);
135 this->add_matcher(m, callback);
138 void construct_add_zero()
140 // pattern #2 : a + 0 = a
141 auto iconst0 = construct_constant_node(0);
142 auto pattern = std::make_shared<pattern::op::Label>(iconst0);
144 auto callback = [pattern](pattern::Matcher& m) {
145 NGRAPH_DEBUG << "In a callback for construct_add_zero against "
146 << m.get_match_root()->get_name();
147 NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
149 auto pattern_map = m.get_pattern_map();
151 size_t const_node_index =
152 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
153 auto const_node = as_type_ptr<op::Constant>(
154 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
156 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
157 NGRAPH_DEBUG << "second_node = " << second_node->get_name()
158 << " , pattern = " << pattern_map[pattern]->get_name();
160 if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
161 pattern_map[pattern]->get_shape() != const_node->get_shape())
163 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
167 auto const_values = const_node->get_vector<int>();
169 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
173 NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
177 ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
181 auto add = pattern + iconst0;
182 auto m = make_shared<TestMatcher>(add);
183 this->add_matcher(m, callback);
189 construct_multiply_by_one();
190 construct_add_zero();
194 static void run_passes(pass::Manager& pass_manager,
195 shared_ptr<Node> graph,
196 std::vector<shared_ptr<op::Parameter>> parms)
198 auto func = make_shared<Function>(graph, ParameterVector{parms});
199 pass_manager.run_passes(func);
202 TEST(pattern, graph_rewrite)
205 pass::Manager pass_manager;
206 pass_manager.register_pass<TestGraphRewrite>();
209 auto a = make_shared<op::Parameter>(element::i32, shape);
210 auto b = make_shared<op::Parameter>(element::i32, shape);
211 auto c = make_shared<op::Parameter>(element::i32, shape);
212 auto iconst0 = construct_constant_node(0);
213 auto graph_a = a + iconst0;
214 auto graph_b = b + iconst0;
216 auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
217 ParameterVector{a, b, c});
218 pass_manager.run_passes(f);
220 ASSERT_TRUE(graph_a->get_output_target_inputs(0).empty());
221 ASSERT_TRUE(graph_b->get_output_target_inputs(0).empty());
223 auto expected = ngraph::NodeVector{a, b, a, c, b};
224 ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
228 auto a = make_shared<op::Parameter>(element::i32, shape);
229 auto b = make_shared<op::Parameter>(element::i32, shape);
230 auto iconst0 = construct_constant_node(0);
231 auto sum = (a + iconst0);
232 auto graph = b + sum;
233 run_passes(pass_manager, graph, {a, b});
234 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
235 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
236 ASSERT_TRUE(sum->output(0)
238 .empty()); // graph's input is removed from sum's target inptus
239 ASSERT_TRUE(a->get_output_target_inputs(0).count(
240 graph->input(1))); // a's output feeds into graph's input
244 auto a = make_shared<op::Parameter>(element::i32, shape);
245 auto b = make_shared<op::Parameter>(element::i32, shape);
246 auto iconst1 = construct_constant_node(1);
247 auto mul = (a * iconst1);
248 auto graph = b + mul;
249 run_passes(pass_manager, graph, {a, b});
250 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
251 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
252 ASSERT_TRUE(mul->output(0)
254 .empty()); // graph's input is removed from sum's target inputs
255 ASSERT_TRUE(a->get_output_target_inputs(0).count(
256 graph->input(1))); // a's output feeds into graph's input
260 auto a = make_shared<op::Parameter>(element::i32, shape);
261 auto b = make_shared<op::Parameter>(element::i32, shape);
262 auto iconst1 = construct_constant_node(1);
263 auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
264 run_passes(pass_manager, graph, {a, b});
265 ASSERT_EQ(graph->input_value(0).get_node_shared_ptr(), a);
266 ASSERT_EQ(graph->input_value(0), a->output(0)); // graph's input points to a's output
267 ASSERT_TRUE(a->get_output_target_inputs(0).count(
268 graph->input(0))); // a's output feeds into graph's input
272 auto a = make_shared<op::Parameter>(element::i32, shape);
273 auto b = make_shared<op::Parameter>(element::i32, shape);
274 auto iconst0 = construct_constant_node(0);
275 auto iconst1 = construct_constant_node(1);
276 auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
277 run_passes(pass_manager, graph, {a, b});
278 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
279 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
280 ASSERT_TRUE(a->get_output_target_inputs(0).count(
281 graph->input(1))); // a's output feeds into graph's input
285 auto a = make_shared<op::Parameter>(element::i32, shape);
286 auto b = make_shared<op::Parameter>(element::i32, shape);
287 auto iconst1 = construct_constant_node(1);
288 auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
289 run_passes(pass_manager, graph, {a, b});
290 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
291 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
292 ASSERT_TRUE(a->get_output_target_inputs(0).count(
293 graph->input(1))); // a's output feeds into graph's input
// Exercises pattern::Matcher against hand-built graphs.
// Covered: plain nodes, Skip, Label (with and without a predicate or wrapped values),
// Any, AnyOf, Or, Branch, and strict-mode type checks.
// The matcher `n` is reused throughout. Each get_matched_nodes() / get_pattern_map()
// check therefore refers to the match() call immediately before it.
297 TEST(pattern, matcher)
// A node matches itself.
300 auto a = make_shared<op::Parameter>(element::i32, shape);
302 ASSERT_TRUE(n.match(a, a));
303 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
// Skip(a) with no predicate consumes `abs` and matches the wrapped `a`
// (the matched nodes are abs, a).
305 auto abs = make_shared<op::Abs>(a);
306 auto any = std::make_shared<pattern::op::Skip>(a);
307 ASSERT_TRUE(n.match(any, abs));
308 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
// With an always-false predicate, Skip still matches `a`, with `a` recorded twice.
310 auto false_pred = [](std::shared_ptr<Node> /* no */) { return false; };
311 auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
312 ASSERT_TRUE(n.match(any_false, a));
313 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
// A Label binds the graph node it matched; a Label whose predicate rejects fails
// and leaves no matched nodes.
315 auto pattern = std::make_shared<pattern::op::Label>(a);
316 ASSERT_TRUE(n.match(pattern, a));
317 ASSERT_EQ(n.get_pattern_map()[pattern], a);
318 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
320 auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
321 ASSERT_FALSE(n.match(pattern_false, a));
322 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
// Any: matches a node accepted by the predicate (binary elementwise arithmetic
// here) whose arguments match {a, b}, in either argument order.
324 auto b = make_shared<op::Parameter>(element::i32, shape);
326 auto is_bea = [](std::shared_ptr<Node> node) -> bool {
327 return op::is_binary_elementwise_arithmetic(node);
329 auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
331 ASSERT_TRUE(n.match(bea, add_ab));
332 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
333 ASSERT_TRUE(n.match(bea, b + a));
335 auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
336 ASSERT_FALSE(n.match(bea_false, a + b));
// AnyOf: the predicate must hold and at least one argument must match `abs`,
// whichever side `abs` is on.
338 auto add_abs_b = abs + b;
339 auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
340 ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
342 auto add_b_abs = b + abs;
343 ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
// Labels wrapped around (or inside) AnyOf / Any bind the corresponding graph node.
345 auto bea_any_of_label =
346 std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
347 ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
348 ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
350 auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
351 auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
352 ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
353 ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
355 auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
357 ASSERT_TRUE(n.match(bea_label, ab));
358 ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
// Distinct Parameters do not match each other, and a structural mismatch leaves
// no matched nodes.
360 auto d = make_shared<op::Parameter>(element::i32, shape);
361 ASSERT_FALSE(n.match(d, b));
363 ASSERT_FALSE(n.match(abs + b, b + b));
364 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
// Skip and Label used as operands of a larger pattern, in both argument orders.
366 auto add_absb = abs + b;
367 ASSERT_TRUE(n.match(any + b, add_absb));
368 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
370 ASSERT_TRUE(n.match(pattern + b, add_absb));
371 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
372 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
374 ASSERT_TRUE(n.match(b + pattern, add_absb));
375 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
376 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
// Nested patterns, including permuted arguments at more than one level.
378 auto c = make_shared<op::Parameter>(element::i32, shape);
379 auto mul_add_absb = c * (add_absb);
380 ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
381 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
382 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
384 ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
385 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
386 ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
387 auto mul_c_add_ab = c * add_ab;
388 ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); // nested any
389 ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
390 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
// A pattern Constant matches a different Constant node instance.
392 auto iconst1_0 = construct_constant_node(1);
393 auto iconst1_1 = construct_constant_node(1);
394 ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
395 ASSERT_EQ(n.get_pattern_map()[pattern], a);
396 auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
397 auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
398 ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); // f32 pattern vs i32 graph: element types differ, still matches
// A Label wrapping `add` binds only graph nodes of the same structure.
402 auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
403 ASSERT_TRUE(n.match(label, add));
404 ASSERT_EQ(n.get_pattern_map()[label], add);
405 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
407 ASSERT_FALSE(n.match(label, a - b));
409 ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
410 ASSERT_EQ(n.get_pattern_map()[label], add);
412 // Correct argument order
413 ASSERT_FALSE(n.match(b - a, a - b));
414 auto aab = a * (a - b);
415 auto paab = pattern * (pattern - b);
416 ASSERT_TRUE(n.match(paab, aab));
417 auto aba = a * (b - a);
418 ASSERT_FALSE(n.match(paab, aba));
419 auto paba = pattern * (b - pattern);
420 ASSERT_FALSE(n.match(paba, aab));
// A label used twice in one pattern must bind consistently.
423 auto label1 = std::make_shared<pattern::op::Label>(a);
424 auto tmp = label1 + b;
425 auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
426 auto sub_label1 = label1 - label2;
427 auto sub_add = a - add;
428 ASSERT_TRUE(n.match(sub_label1, sub_add));
429 ASSERT_EQ(n.get_pattern_map()[label1], a);
430 ASSERT_EQ(n.get_pattern_map()[label2], add);
431 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
433 ASSERT_FALSE(n.match(sub_label1, add - a));
435 auto add_label1 = label1 + label2;
436 ASSERT_TRUE(n.match(add_label1, add + a));
437 ASSERT_EQ(n.get_pattern_map()[label1], a);
438 ASSERT_EQ(n.get_pattern_map()[label2], add);
// Or: matches if any alternative matches.
441 ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
442 ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
// Branch: `star` is either a recursive reference back to `pattern` or True, so
// `pattern` matches arbitrarily nested sums.
446 auto branch = std::make_shared<pattern::op::Branch>();
447 auto star = std::make_shared<pattern::op::Or>(
448 OutputVector{branch, std::make_shared<pattern::op::True>()});
449 auto pattern = star + star;
450 branch->set_destination(pattern);
451 ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
452 ASSERT_EQ(n.get_matched_nodes().size(), 4);
// Strict mode (third constructor argument `true`): labels also check element
// type and shape. Dynamic shape, dimension or type in the label is accepted.
457 TestMatcher sm(Output<Node>{}, "TestMatcher", true);
458 // exact shape and type
459 auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
460 auto label_dynamic_shape =
461 make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
462 auto param = make_shared<op::Parameter>(element::f32, Shape{});
463 ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
// Wrong element type: the f32 parameter is rejected in strict mode.
// NOTE(review): `label` is the Label wrapping `add` defined above.
465 auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
466 ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
468 auto label_dynamic_dimension =
469 make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
470 auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
471 ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
473 auto label_dynamic_type =
474 make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
475 ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
// Concrete graph: sum(input) / N over axis 0 of a 2x3 f32 Parameter.
// It has the same structure as the pattern built by construct_mean_graph().
484 auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
485 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
486 auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
487 auto mean = std::make_shared<op::Divide>(sum_input1, N);
// The mean pattern must match, and its root label must bind to the concrete Divide.
489 auto mean_graph = construct_mean_graph();
490 ASSERT_TRUE(n.match(mean_graph, mean));
491 ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
494 TEST(pattern, variance)
496 // construct variance
498 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
499 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
500 auto input_sq = std::make_shared<op::Multiply>(input, input);
501 auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
502 auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
503 auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
504 auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
505 auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
506 auto variance = std::make_shared<op::Divide>(xmu, N);
508 auto var_graph = construct_variance_graph();
509 ASSERT_TRUE(n.match(var_graph, variance));
510 ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
513 TEST(pattern, previous_matches)
515 using ngraph::pattern::Matcher;
517 Matcher::PatternMap previous_matches;
518 auto a = make_shared<op::Parameter>(element::i32, shape);
519 auto b = make_shared<op::Parameter>(element::i32, shape);
520 auto pattern = std::make_shared<pattern::op::Label>(b);
521 auto abs = make_shared<op::Abs>(a);
524 Matcher n(pattern + b);
525 ASSERT_TRUE(n.match(add, previous_matches));
526 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
530 Matcher n(pattern + b);
531 previous_matches.insert(std::make_pair(pattern, a));
532 ASSERT_FALSE(n.match(add, previous_matches));
536 TEST(pattern, test_sort)
538 using ngraph::pattern::Matcher;
541 auto a = make_shared<op::Parameter>(element::i32, shape);
542 auto b = make_shared<op::Parameter>(element::i32, shape);
543 auto abs1 = make_shared<op::Abs>(a);
544 auto abs2 = make_shared<op::Abs>(b);
545 auto add = abs1 + abs2;
547 auto pa = make_shared<op::Parameter>(element::i32, shape);
548 auto pb = make_shared<op::Parameter>(element::i32, shape);
549 auto pabs1 = make_shared<op::Abs>(pa);
550 auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
551 auto pabs2 = make_shared<op::Abs>(b);
552 auto padd = pabs1_label + pabs2;
556 ASSERT_TRUE(n1.match(add));
557 auto r1 = n1.get_pattern_map()[pabs1_label];
558 ASSERT_TRUE(n1.match(add));
559 ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
// Exercises RecurrentMatcher on chains of `const + x` adds.
// It checks the number of bound labels, the per-level bindings for the recurring
// label, and correlated labels. The last case reuses `rm3` on a second graph.
563 TEST(pattern, recurrent_pattern)
565 using ngraph::pattern::RecurrentMatcher;
// NOTE(review): `previous_matches` and `abs` below are never used by this test.
567 ngraph::pattern::Matcher::PatternMap previous_matches;
568 auto a = make_shared<op::Parameter>(element::i32, shape);
569 auto b = make_shared<op::Parameter>(element::i32, shape);
570 auto rpattern = std::make_shared<pattern::op::Label>(b);
571 auto iconst0 = construct_constant_node(0);
572 auto abs = make_shared<op::Abs>(a);
// Graph: 0 + (0 + (0 + b)); pattern: 0 + rpattern, recurring through rpattern.
573 auto add1 = iconst0 + b;
574 auto add2 = iconst0 + add1;
575 auto add3 = iconst0 + add2;
576 auto padd = iconst0 + rpattern;
577 std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
578 RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
579 ASSERT_TRUE(rm.match(add3));
580 ASSERT_EQ(rm.get_number_of_bound_labels(), 3);
// Bindings for rpattern run from the outermost level inward; the last one is the seed `b`.
581 auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
582 ASSERT_EQ(recurrent_matches.at(0), add2);
583 ASSERT_EQ(recurrent_matches.at(1), add1);
584 ASSERT_EQ(recurrent_matches.at(2), b);
586 // Multiple labels in a recurring pattern
587 auto iconst1 = construct_constant_node(1);
588 auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
589 auto add2_2 = iconst1 + add1;
590 auto add3_2 = iconst0 + add2_2;
591 auto padd2 = iconst_label + rpattern;
592 RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
593 ASSERT_TRUE(rm2.match(add3_2));
594 ASSERT_EQ(rm2.get_number_of_bound_labels(), 4);
595 recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
596 ASSERT_EQ(recurrent_matches.at(0), add2_2);
597 ASSERT_EQ(recurrent_matches.at(1), add1);
598 ASSERT_EQ(recurrent_matches.at(2), b);
// iconst_label binds independently at each level: 0, then 1, then 0.
599 auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
600 ASSERT_EQ(iconst_matches.at(0), iconst0);
601 ASSERT_EQ(iconst_matches.at(1), iconst1);
602 ASSERT_EQ(iconst_matches.at(2), iconst0);
604 // Non-matching correlated labels
// With iconst_label correlated, the match stops at the level where the constant
// changes, so only one binding is recorded.
605 std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
606 correlated_matches.insert(iconst_label);
607 RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
608 ASSERT_TRUE(rm3.match(add3_2));
609 ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
610 iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
611 ASSERT_EQ(iconst_matches.size(), 1);
612 ASSERT_EQ(iconst_matches.at(0), iconst0);
614 // Matching correlated labels and
615 // testing if RecurrentMatcher can be reused for different nodes
616 ASSERT_TRUE(rm3.match(add3));
617 ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
618 recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
619 ASSERT_EQ(recurrent_matches.at(0), add2);
620 ASSERT_EQ(recurrent_matches.at(1), add1);
621 ASSERT_EQ(recurrent_matches.at(2), b);
622 iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
623 ASSERT_EQ(iconst_matches.at(0), iconst0);
624 ASSERT_EQ(iconst_matches.at(1), iconst0);
625 ASSERT_EQ(iconst_matches.at(2), iconst0);
628 class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
631 void construct_recurrent_add()
634 auto iconst0 = construct_constant_node(0);
636 std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
637 auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
638 auto padd = iconst_label + rpattern;
640 auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
641 NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
642 << rm.get_match_root()->get_name();
644 auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
646 auto is_iconst_zero = [](std::shared_ptr<Node> n) {
647 bool result = ngraph::is_zero(n);
648 NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
649 return ngraph::is_zero(n);
652 bool are_all_iconst_zeros =
653 std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
655 if (!are_all_iconst_zeros)
660 auto number_of_adds = rm.get_number_of_recurrent_matches();
661 // replace the topmost add with the seed (i.e. the first parameter to add)
662 // matches are added in reverse order (i.e. the first match is the topmost node)
663 auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
664 NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
666 ngraph::replace_node(rm.get_match_root(), arg);
670 std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
671 auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
672 this->add_matcher(rm, callback);
675 TestRecurrentGraphRewrite()
676 : RecurrentGraphRewrite()
678 construct_recurrent_add();
682 TEST(pattern, recurrent_graph_rewrite)
685 pass::Manager pass_manager;
686 pass_manager.register_pass<TestRecurrentGraphRewrite>();
689 auto a = make_shared<op::Parameter>(element::i32, shape);
690 auto iconst0 = construct_constant_node(0);
691 auto add_a1 = a + iconst0;
692 auto add_a2 = add_a1 + iconst0;
693 auto add_a3 = add_a2 + iconst0;
694 auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
696 auto b = make_shared<op::Parameter>(element::i32, shape);
697 auto add_b1 = b + iconst0;
698 auto add_b2 = add_b1 + iconst0;
699 auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
701 auto graph = abs_add_a3 * abs_add_b2;
703 auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
704 pass_manager.run_passes(f);
706 auto left_abs = graph->input_value(0).get_node_shared_ptr();
707 auto add_a = left_abs->input_value(0).get_node_shared_ptr();
710 auto right_abs = graph->input_value(1).get_node_shared_ptr();
711 auto add_b = right_abs->input_value(0).get_node_shared_ptr();
// Pattern: label * Label(Skip(const_label)), where the Skip steps over a Broadcast
// (bcst_pred). The same matcher must bind correctly in two cases:
//   - a * broadcast(0): the Broadcast is skipped;
//   - b * 0: there is no Broadcast, so nothing is skipped.
716 TEST(pattern, label_on_skip)
719 auto a = make_shared<op::Parameter>(element::i32, shape);
720 auto b = make_shared<op::Parameter>(element::i32, Shape{});
721 auto iconst = ngraph::make_zero(element::i32, Shape{});
// const_label (next statement) accepts only nodes for which ngraph::is_zero holds.
722 auto label = std::make_shared<pattern::op::Label>(iconst);
724 std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
726 auto bcst_pred = [](std::shared_ptr<Node> n) {
727 return as_type_ptr<op::Broadcast>(n) != nullptr;
730 auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
731 auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
732 auto matcher = std::make_shared<pattern::Matcher>(
733 std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");
// Case 1: broadcast operand. bcst_label binds the Broadcast; const_label binds its scalar input.
735 auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
736 auto mul = a * const_broadcast;
737 auto mul_scalar = b * iconst;
738 ASSERT_TRUE(matcher->match(mul));
739 ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
740 ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
741 ASSERT_EQ(matcher->get_pattern_map()[label], a);
// Case 2: scalar operand. bcst_label and const_label both bind the constant itself.
742 ASSERT_TRUE(matcher->match(mul_scalar));
743 ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
744 ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
745 ASSERT_EQ(matcher->get_pattern_map()[label], b);
// Checks Matcher::is_contained_match() after two matches.
// First: abs(a) matched against absn, which feeds only result_absn; the call returns true.
// Second: abs(abs(a)) matched against absn2, while the inner absn also feeds
// result_absn; the call returns false.
// NOTE(review): presumably false because a matched node has a user outside the
// match; confirm against Matcher::is_contained_match.
748 TEST(pattern, is_contained_match)
751 auto a = make_shared<op::Parameter>(element::i32, shape);
752 auto absn = make_shared<op::Abs>(a);
// NOTE(review): label_a is never used; the patterns below are built on `a` directly.
755 auto label_a = std::make_shared<pattern::op::Label>(a);
756 auto label_abs = make_shared<op::Abs>(a);
757 ASSERT_TRUE(n.match(label_abs, absn));
758 auto result_absn = make_shared<op::Result>(absn);
759 ASSERT_TRUE(n.is_contained_match());
761 auto absn2 = make_shared<op::Abs>(absn);
762 auto result_absn2 = make_shared<op::Result>(absn2);
763 auto label_abs2 = make_shared<op::Abs>(label_abs);
764 ASSERT_TRUE(n.match(label_abs2, absn2));
765 ASSERT_FALSE(n.is_contained_match());
// Exercises pattern::wrap_type<T>(): type-only matching, with or without input patterns.
768 TEST(pattern, wrap_type)
770 auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
771 auto b = make_shared<op::Abs>(a);
772 auto c = make_shared<op::Relu>(a);
// mul1 has the constant second; mul2 has it first.
773 auto mul1 = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
774 auto mul2 = make_shared<op::v1::Multiply>(op::Constant::create(element::f32, Shape{}, {1}), a);
// wrap_type<Abs>() with no inputs: matches the Abs node alone (one matched node)
// and rejects Relu.
777 auto m = pattern::wrap_type<op::Abs>();
778 auto matcher = std::make_shared<pattern::Matcher>(m, "AbsMatcher");
779 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
780 ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
781 ASSERT_EQ(matcher->get_matched_nodes()[0], b);
782 ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
783 ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
// wrap_type<Abs>({wrap_type<Parameter>()}): both wrapped patterns appear in the
// pattern map.
786 auto m1 = pattern::wrap_type<op::Parameter>();
787 auto m2 = pattern::wrap_type<op::Abs>({m1});
788 auto matcher = std::make_shared<pattern::Matcher>(m2, "ParamAbsMatcher");
789 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
790 ASSERT_EQ(matcher->get_matched_nodes().size(), 2);
791 ASSERT_EQ(matcher->get_pattern_map().count(m1), 1);
792 ASSERT_EQ(matcher->get_pattern_map().count(m2), 1);
793 ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
// Multiply(any, Constant) matches both operand orders of v1::Multiply.
796 auto m1 = pattern::wrap_type<op::v1::Multiply>(
797 {pattern::any_input(), pattern::wrap_type<op::Constant>()});
798 auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
799 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
800 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
// The same check with the input patterns swapped: Multiply(Constant, any).
803 auto m1 = pattern::wrap_type<op::v1::Multiply>(
804 {pattern::wrap_type<op::Constant>(), pattern::any_input()});
805 auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
806 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
807 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));