1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
22 #include <ngraph/pattern/op/wrap_type.hpp>
24 #include "gtest/gtest.h"
25 #include "ngraph/file_util.hpp"
26 #include "ngraph/graph_util.hpp"
27 #include "ngraph/log.hpp"
28 #include "ngraph/ngraph.hpp"
29 #include "ngraph/op/add.hpp"
30 #include "ngraph/op/batch_norm.hpp"
31 #include "ngraph/op/constant.hpp"
32 #include "ngraph/op/divide.hpp"
33 #include "ngraph/op/multiply.hpp"
34 #include "ngraph/op/sqrt.hpp"
35 #include "ngraph/op/subtract.hpp"
36 #include "ngraph/op/sum.hpp"
37 #include "ngraph/op/sum.hpp"
38 #include "ngraph/op/util/op_types.hpp"
39 #include "ngraph/pass/graph_rewrite.hpp"
40 #include "ngraph/pass/manager.hpp"
41 #include "ngraph/pattern/matcher.hpp"
42 #include "ngraph/pattern/op/branch.hpp"
43 #include "ngraph/pattern/op/label.hpp"
44 #include "ngraph/pattern/op/or.hpp"
45 #include "ngraph/pattern/op/skip.hpp"
46 #include "ngraph/pattern/op/true.hpp"
47 #include "util/matcher.hpp"
48 #include "util/test_tools.hpp"
50 NGRAPH_SUPPRESS_DEPRECATED_START
52 using namespace ngraph;
55 static std::shared_ptr<Node> construct_constant_node(int n)
57 return op::Constant::create(element::i32, Shape{}, {n});
60 static std::shared_ptr<pattern::op::Label> construct_variance_graph()
62 // construct varaiance
63 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
64 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
65 auto input_sq = std::make_shared<op::Multiply>(input, input);
66 auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
67 auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
68 auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
69 auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
70 auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
71 auto variance = std::make_shared<op::Divide>(xmu, N);
73 std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
75 return variance_label;
78 static std::shared_ptr<pattern::op::Label> construct_mean_graph()
81 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
82 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
83 auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
84 auto mean = std::make_shared<op::Divide>(sum_input1, N);
85 auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
89 class TestGraphRewrite : public ngraph::pass::GraphRewrite
92 void construct_multiply_by_one()
94 // pattern #1 : a * 1 = a
95 auto iconst1 = construct_constant_node(1);
96 auto pattern = std::make_shared<pattern::op::Label>(iconst1);
98 auto callback = [pattern](pattern::Matcher& m) {
99 NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
100 << m.get_match_root()->get_name();
101 NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
103 auto pattern_map = m.get_pattern_map();
105 size_t const_node_index =
106 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
107 auto const_node = as_type_ptr<op::Constant>(
108 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
110 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
111 NGRAPH_DEBUG << "second_node = " << second_node->get_name()
112 << " , pattern = " << pattern_map[pattern]->get_name();
114 if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
115 pattern_map[pattern]->get_shape() != const_node->get_shape())
117 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
121 auto const_values = const_node->get_vector<int32_t>();
123 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
127 NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
131 ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
135 auto m = make_shared<TestMatcher>(pattern * iconst1);
136 NGRAPH_SUPPRESS_DEPRECATED_START
137 this->add_matcher(m, callback);
138 NGRAPH_SUPPRESS_DEPRECATED_END
141 void construct_add_zero()
143 // pattern #2 : a + 0 = a
144 auto iconst0 = construct_constant_node(0);
145 auto pattern = std::make_shared<pattern::op::Label>(iconst0);
147 auto callback = [pattern](pattern::Matcher& m) {
148 NGRAPH_DEBUG << "In a callback for construct_add_zero against "
149 << m.get_match_root()->get_name();
150 NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
152 auto pattern_map = m.get_pattern_map();
154 size_t const_node_index =
155 m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];
156 auto const_node = as_type_ptr<op::Constant>(
157 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
159 m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
160 NGRAPH_DEBUG << "second_node = " << second_node->get_name()
161 << " , pattern = " << pattern_map[pattern]->get_name();
163 if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
164 pattern_map[pattern]->get_shape() != const_node->get_shape())
166 NGRAPH_DEBUG << "Operands' types and/or shape don't match";
170 auto const_values = const_node->get_vector<int>();
172 std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
176 NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
180 ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
184 auto add = pattern + iconst0;
185 auto m = make_shared<TestMatcher>(add);
186 NGRAPH_SUPPRESS_DEPRECATED_START
187 this->add_matcher(m, callback);
188 NGRAPH_SUPPRESS_DEPRECATED_END
194 construct_multiply_by_one();
195 construct_add_zero();
199 static void run_passes(pass::Manager& pass_manager,
200 shared_ptr<Node> graph,
201 std::vector<shared_ptr<op::Parameter>> parms)
203 auto func = make_shared<Function>(graph, ParameterVector{parms});
204 pass_manager.run_passes(func);
// Runs TestGraphRewrite (a * 1 -> a, a + 0 -> a) on several small i32 graphs and
// checks that each identity op is bypassed: its consumers are re-wired to the
// surviving operand and the op itself is left with no consumers.
207 TEST(pattern, graph_rewrite)
210 pass::Manager pass_manager;
211 pass_manager.register_pass<TestGraphRewrite>();
// Case: two independent `x + 0` nodes used directly as Function results.
214 auto a = make_shared<op::Parameter>(element::i32, shape);
215 auto b = make_shared<op::Parameter>(element::i32, shape);
216 auto c = make_shared<op::Parameter>(element::i32, shape);
217 auto iconst0 = construct_constant_node(0);
218 auto graph_a = a + iconst0;
219 auto graph_b = b + iconst0;
221 auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
222 ParameterVector{a, b, c});
223 pass_manager.run_passes(f);
// Both Adds must have lost all consumers and no Add may remain in f.
225 ASSERT_TRUE(graph_a->get_output_target_inputs(0).empty());
226 ASSERT_TRUE(graph_b->get_output_target_inputs(0).empty());
// NOTE(review): `expected` is never compared against anything; either assert on
// f's results or remove it.
228 auto expected = ngraph::NodeVector{a, b, a, c, b};
229 ASSERT_TRUE(count_ops_of_type<op::Add>(f) == 0);
// Case: b + (a + 0) -- the inner Add is removed and b's Add reads a directly.
233 auto a = make_shared<op::Parameter>(element::i32, shape);
234 auto b = make_shared<op::Parameter>(element::i32, shape);
235 auto iconst0 = construct_constant_node(0);
236 auto sum = (a + iconst0);
237 auto graph = b + sum;
238 run_passes(pass_manager, graph, {a, b});
239 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
240 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
241 ASSERT_TRUE(sum->output(0)
243 .empty()); // graph's input is removed from sum's target inputs
244 ASSERT_TRUE(a->get_output_target_inputs(0).count(
245 graph->input(1))); // a's output feeds into graph's input
// Case: b + (a * 1) -- the Multiply is removed and b's Add reads a directly.
249 auto a = make_shared<op::Parameter>(element::i32, shape);
250 auto b = make_shared<op::Parameter>(element::i32, shape);
251 auto iconst1 = construct_constant_node(1);
252 auto mul = (a * iconst1);
253 auto graph = b + mul;
254 run_passes(pass_manager, graph, {a, b});
255 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
256 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
257 ASSERT_TRUE(mul->output(0)
259 .empty()); // graph's input is removed from mul's target inputs
260 ASSERT_TRUE(a->get_output_target_inputs(0).count(
261 graph->input(1))); // a's output feeds into graph's input
// Case: a chain of four `* 1` on the left operand collapses to a.
265 auto a = make_shared<op::Parameter>(element::i32, shape);
266 auto b = make_shared<op::Parameter>(element::i32, shape);
267 auto iconst1 = construct_constant_node(1);
268 auto graph = ((((a * iconst1) * iconst1) * iconst1) * iconst1) + b;
269 run_passes(pass_manager, graph, {a, b});
270 ASSERT_EQ(graph->input_value(0).get_node_shared_ptr(), a);
271 ASSERT_EQ(graph->input_value(0), a->output(0)); // graph's input points to a's output
272 ASSERT_TRUE(a->get_output_target_inputs(0).count(
273 graph->input(0))); // a's output feeds into graph's input
// Case: interleaved `+ 0` / `* 1` (constants on both sides) collapse to a.
277 auto a = make_shared<op::Parameter>(element::i32, shape);
278 auto b = make_shared<op::Parameter>(element::i32, shape);
279 auto iconst0 = construct_constant_node(0);
280 auto iconst1 = construct_constant_node(1);
281 auto graph = b + (iconst0 + ((a + iconst0) * iconst1));
282 run_passes(pass_manager, graph, {a, b});
283 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
284 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
285 ASSERT_TRUE(a->get_output_target_inputs(0).count(
286 graph->input(1))); // a's output feeds into graph's input
// Case: constant on the LEFT of every Multiply (`1 * x`). Even though the patterns
// are written as `x * 1`, this is still expected to collapse to a.
290 auto a = make_shared<op::Parameter>(element::i32, shape);
291 auto b = make_shared<op::Parameter>(element::i32, shape);
292 auto iconst1 = construct_constant_node(1);
293 auto graph = b + (iconst1 * (iconst1 * (iconst1 * (iconst1 * a))));
294 run_passes(pass_manager, graph, {a, b});
295 ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
296 ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
297 ASSERT_TRUE(a->get_output_target_inputs(0).count(
298 graph->input(1))); // a's output feeds into graph's input
// Exercises the basic pattern-matcher building blocks against small i32 graphs:
// exact nodes, Skip, Label (with and without predicates), Any/AnyOf, nested and
// permuted binary ops, Or, recursive Branch patterns, and dynamic shapes/types.
// Each assertion also pins the order of get_matched_nodes() where it matters.
// NOTE(review): `n` is presumably a TestMatcher declared at the top of this test.
302 TEST(pattern, matcher)
305 auto a = make_shared<op::Parameter>(element::i32, shape);
307 ASSERT_TRUE(n.match(a, a));
308 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
// Skip(a) matches abs by skipping over it down to a (matched nodes: abs, a).
310 auto abs = make_shared<op::Abs>(a);
311 auto any = std::make_shared<pattern::op::Skip>(a);
312 ASSERT_TRUE(n.match(any, abs));
313 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
// With a rejecting predicate, Skip does not skip; it still matches `a` itself.
315 auto false_pred = [](std::shared_ptr<Node> /* no */) { return false; };
316 auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
317 ASSERT_TRUE(n.match(any_false, a));
318 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
// A Label binds to the node it matches; a rejecting predicate makes the match fail.
320 auto pattern = std::make_shared<pattern::op::Label>(a);
321 ASSERT_TRUE(n.match(pattern, a));
322 ASSERT_EQ(n.get_pattern_map()[pattern], a);
323 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
325 auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
326 ASSERT_FALSE(n.match(pattern_false, a));
327 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
329 auto b = make_shared<op::Parameter>(element::i32, shape);
// Any: a node accepted by is_bea whose arguments match {a, b} (either order).
331 auto is_bea = [](std::shared_ptr<Node> node) -> bool {
332 return op::is_binary_elementwise_arithmetic(node);
334 auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
336 ASSERT_TRUE(n.match(bea, add_ab));
337 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
338 ASSERT_TRUE(n.match(bea, b + a));
340 auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
341 ASSERT_FALSE(n.match(bea_false, a + b));
// AnyOf: an accepted node where some argument matches `abs`, in either position.
343 auto add_abs_b = abs + b;
344 auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
345 ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
347 auto add_b_abs = b + abs;
348 ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
// Labels wrapping AnyOf, and AnyOf wrapping a Label, both record their bindings.
350 auto bea_any_of_label =
351 std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
352 ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
353 ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
355 auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
356 auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
357 ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
358 ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
360 auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
362 ASSERT_TRUE(n.match(bea_label, ab));
363 ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
// Distinct Parameters do not match each other; a failed match clears matched nodes.
365 auto d = make_shared<op::Parameter>(element::i32, shape);
366 ASSERT_FALSE(n.match(d, b));
368 ASSERT_FALSE(n.match(abs + b, b + b));
369 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
// Binary patterns match with operands in either order.
371 auto add_absb = abs + b;
372 ASSERT_TRUE(n.match(any + b, add_absb));
373 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
375 ASSERT_TRUE(n.match(pattern + b, add_absb));
376 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
377 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
379 ASSERT_TRUE(n.match(b + pattern, add_absb));
380 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
381 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
383 auto c = make_shared<op::Parameter>(element::i32, shape);
384 auto mul_add_absb = c * (add_absb);
385 ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
386 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
387 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
389 ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); // nested any
390 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
391 ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); // permutations w/ any
392 auto mul_c_add_ab = c * add_ab;
393 ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); // nested any_false
394 ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); // permutations w/ any_false
395 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
// Constants: distinct Constant nodes with equal values match each other.
397 auto iconst1_0 = construct_constant_node(1);
398 auto iconst1_1 = construct_constant_node(1);
399 ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); // different iconst
400 ASSERT_EQ(n.get_pattern_map()[pattern], a);
401 auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
402 auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
403 ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); // f32 pattern vs i32 graph still matches
// A Label wrapping a sub-pattern binds to the matched sub-graph root.
407 auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
408 ASSERT_TRUE(n.match(label, add));
409 ASSERT_EQ(n.get_pattern_map()[label], add);
410 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
412 ASSERT_FALSE(n.match(label, a - b));
414 ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
415 ASSERT_EQ(n.get_pattern_map()[label], add);
417 // Correct argument order
// (Subtract is not permuted, so operand order must agree exactly.)
418 ASSERT_FALSE(n.match(b - a, a - b));
419 auto aab = a * (a - b);
420 auto paab = pattern * (pattern - b);
421 ASSERT_TRUE(n.match(paab, aab));
422 auto aba = a * (b - a);
423 ASSERT_FALSE(n.match(paab, aba));
424 auto paba = pattern * (b - pattern);
425 ASSERT_FALSE(n.match(paba, aab));
// Nested labels: label1 must bind consistently both at top level and inside label2.
428 auto label1 = std::make_shared<pattern::op::Label>(a);
429 auto tmp = label1 + b;
430 auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
431 auto sub_label1 = label1 - label2;
432 auto sub_add = a - add;
433 ASSERT_TRUE(n.match(sub_label1, sub_add));
434 ASSERT_EQ(n.get_pattern_map()[label1], a);
435 ASSERT_EQ(n.get_pattern_map()[label2], add);
436 ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
438 ASSERT_FALSE(n.match(sub_label1, add - a));
440 auto add_label1 = label1 + label2;
441 ASSERT_TRUE(n.match(add_label1, add + a));
442 ASSERT_EQ(n.get_pattern_map()[label1], a);
443 ASSERT_EQ(n.get_pattern_map()[label2], add);
// Or: matches if any alternative matches.
446 ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
447 ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
// Recursive pattern: `star` is either the whole pattern again (via Branch) or
// anything (True), so `star + star` matches an Add tree of any depth.
451 auto branch = std::make_shared<pattern::op::Branch>();
452 auto star = std::make_shared<pattern::op::Or>(
453 OutputVector{branch, std::make_shared<pattern::op::True>()});
454 auto pattern = star + star;
455 branch->set_destination(pattern);
456 ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
457 ASSERT_EQ(n.get_matched_nodes().size(), 4);
// Dynamic shapes and element types on Labels (matcher constructed with strict mode = true).
462 TestMatcher sm(Output<Node>{}, "TestMatcher", true);
463 // exact shape and type
464 auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
465 auto label_dynamic_shape =
466 make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
// NOTE(review): `param` is never used below.
467 auto param = make_shared<op::Parameter>(element::f32, Shape{});
468 ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
// `label` is presumably an i32 Label declared just above; an f32 node must not match it.
470 auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
471 ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
473 auto label_dynamic_dimension =
474 make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
475 auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
476 ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
478 auto label_dynamic_type =
479 make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
480 ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
// Builds a concrete mean graph: the axis-0 Sum of an f32 {2, 3} Parameter divided
// by the constant {2, 2, 2}. Checks that construct_mean_graph()'s pattern matches
// it and that the pattern's label binds to the concrete Divide root.
// NOTE(review): `n` is presumably a TestMatcher declared earlier in this test.
489 auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
490 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
491 auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
492 auto mean = std::make_shared<op::Divide>(sum_input1, N);
494 auto mean_graph = construct_mean_graph();
495 ASSERT_TRUE(n.match(mean_graph, mean));
496 ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
// Rebuilds, node for node, the same expression used by construct_variance_graph()
// (it also uses a Label as its input). Checks that the pattern matches it and that
// the pattern's label binds to the final Divide.
// NOTE(review): `n` is presumably a TestMatcher declared on a line not shown here.
499 TEST(pattern, variance)
501 // construct variance
503 auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
504 auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
505 auto input_sq = std::make_shared<op::Multiply>(input, input);
506 auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
// (sum of x) squared; the local name's "sumed" is a misspelling of "summed".
507 auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
508 auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
509 auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
510 auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
511 auto variance = std::make_shared<op::Divide>(xmu, N);
513 auto var_graph = construct_variance_graph();
514 ASSERT_TRUE(n.match(var_graph, variance));
515 ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
// Checks that Matcher::match honours bindings supplied in `previous_matches`.
// - With no prior bindings, `pattern` binds to `abs`.
// - Once `pattern` is pre-bound to `a`, the same graph no longer matches.
// `add` is expected to be an Add combining `abs` and `b`; it is not declared in
// the lines shown here.
518 TEST(pattern, previous_matches)
520 using ngraph::pattern::Matcher;
522 Matcher::PatternMap previous_matches;
523 auto a = make_shared<op::Parameter>(element::i32, shape);
524 auto b = make_shared<op::Parameter>(element::i32, shape);
525 auto pattern = std::make_shared<pattern::op::Label>(b);
526 auto abs = make_shared<op::Abs>(a);
529 Matcher n(pattern + b);
530 ASSERT_TRUE(n.match(add, previous_matches));
531 ASSERT_EQ(n.get_pattern_map()[pattern], abs);
535 Matcher n(pattern + b);
// Pre-bind `pattern` to `a`; the graph's operand is `abs`, so the match must fail.
536 previous_matches.insert(std::make_pair(pattern, a));
537 ASSERT_FALSE(n.match(add, previous_matches));
// Matches the same graph twice with the same matcher and checks that the label
// binds to the same node both times, i.e. matching is deterministic.
// NOTE(review): `n1` is presumably a Matcher over `padd` declared on a line not shown.
541 TEST(pattern, test_sort)
543 using ngraph::pattern::Matcher;
546 auto a = make_shared<op::Parameter>(element::i32, shape);
547 auto b = make_shared<op::Parameter>(element::i32, shape);
548 auto abs1 = make_shared<op::Abs>(a);
549 auto abs2 = make_shared<op::Abs>(b);
550 auto add = abs1 + abs2;
552 auto pa = make_shared<op::Parameter>(element::i32, shape);
// NOTE(review): `pb` is unused.
553 auto pb = make_shared<op::Parameter>(element::i32, shape);
554 auto pabs1 = make_shared<op::Abs>(pa);
555 auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
// `pabs2` wraps the graph's own `b` rather than `pb`. A distinct Parameter does not
// match `b` (see `n.match(d, b)` failing in TEST(pattern, matcher)), so `pabs2` can
// only match `abs2`.
556 auto pabs2 = make_shared<op::Abs>(b);
557 auto padd = pabs1_label + pabs2;
561 ASSERT_TRUE(n1.match(add));
562 auto r1 = n1.get_pattern_map()[pabs1_label];
563 ASSERT_TRUE(n1.match(add));
564 ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
// Exercises RecurrentMatcher on chains of `const + x`. The recurring label
// `rpattern` is bound once per level of the chain; bindings are listed from the
// topmost match down to the seed.
568 TEST(pattern, recurrent_pattern)
570 using ngraph::pattern::RecurrentMatcher;
// NOTE(review): `previous_matches` and `abs` below are never used in this test.
572 ngraph::pattern::Matcher::PatternMap previous_matches;
573 auto a = make_shared<op::Parameter>(element::i32, shape);
574 auto b = make_shared<op::Parameter>(element::i32, shape);
575 auto rpattern = std::make_shared<pattern::op::Label>(b);
576 auto iconst0 = construct_constant_node(0);
577 auto abs = make_shared<op::Abs>(a);
578 auto add1 = iconst0 + b;
579 auto add2 = iconst0 + add1;
580 auto add3 = iconst0 + add2;
581 auto padd = iconst0 + rpattern;
582 std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
583 RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
584 ASSERT_TRUE(rm.match(add3));
585 ASSERT_EQ(rm.get_number_of_bound_labels(), 3);
586 auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
587 ASSERT_EQ(recurrent_matches.at(0), add2);
588 ASSERT_EQ(recurrent_matches.at(1), add1);
589 ASSERT_EQ(recurrent_matches.at(2), b);
591 // Multiple labels in a recurring pattern
592 auto iconst1 = construct_constant_node(1);
593 auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
594 auto add2_2 = iconst1 + add1;
595 auto add3_2 = iconst0 + add2_2;
596 auto padd2 = iconst_label + rpattern;
597 RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
598 ASSERT_TRUE(rm2.match(add3_2));
599 ASSERT_EQ(rm2.get_number_of_bound_labels(), 4);
600 recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
601 ASSERT_EQ(recurrent_matches.at(0), add2_2);
602 ASSERT_EQ(recurrent_matches.at(1), add1);
603 ASSERT_EQ(recurrent_matches.at(2), b);
// Without correlation, iconst_label may bind to a different constant at each level.
604 auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
605 ASSERT_EQ(iconst_matches.at(0), iconst0);
606 ASSERT_EQ(iconst_matches.at(1), iconst1);
607 ASSERT_EQ(iconst_matches.at(2), iconst0);
609 // Non-matching correlated labels
// With iconst_label correlated, it must bind to the same node at every level. The
// match is still reported as successful, but only one binding (the first level) is
// recorded for it.
610 std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
611 correlated_matches.insert(iconst_label);
612 RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
613 ASSERT_TRUE(rm3.match(add3_2));
614 ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
615 iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
616 ASSERT_EQ(iconst_matches.size(), 1);
617 ASSERT_EQ(iconst_matches.at(0), iconst0);
619 // Matching correlated labels and
620 // testing if RecurrentMatcher can be reused for different nodes
621 ASSERT_TRUE(rm3.match(add3));
622 ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
623 recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
624 ASSERT_EQ(recurrent_matches.at(0), add2);
625 ASSERT_EQ(recurrent_matches.at(1), add1);
626 ASSERT_EQ(recurrent_matches.at(2), b);
627 iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
628 ASSERT_EQ(iconst_matches.at(0), iconst0);
629 ASSERT_EQ(iconst_matches.at(1), iconst0);
630 ASSERT_EQ(iconst_matches.at(2), iconst0);
633 class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
636 void construct_recurrent_add()
639 auto iconst0 = construct_constant_node(0);
641 std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
642 auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
643 auto padd = iconst_label + rpattern;
645 auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
646 NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
647 << rm.get_match_root()->get_name();
649 auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
651 auto is_iconst_zero = [](std::shared_ptr<Node> n) {
652 bool result = ngraph::is_zero(n);
653 NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
654 return ngraph::is_zero(n);
657 bool are_all_iconst_zeros =
658 std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
660 if (!are_all_iconst_zeros)
665 auto number_of_adds = rm.get_number_of_recurrent_matches();
666 // replace the topmost add with the seed (i.e. the first parameter to add)
667 // matches are added in reverse order (i.e. the first match is the topmost node)
668 auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
669 NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
671 ngraph::replace_node(rm.get_match_root(), arg);
675 std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
676 auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
677 NGRAPH_SUPPRESS_DEPRECATED_START
678 this->add_matcher(rm, callback);
679 NGRAPH_SUPPRESS_DEPRECATED_END
682 TestRecurrentGraphRewrite()
683 : RecurrentGraphRewrite()
685 construct_recurrent_add();
// Runs TestRecurrentGraphRewrite over two `x + 0 + 0 ...` chains, each feeding an
// Abs, and then inspects what each Abs reads after the pass.
// The graph places the zero on the right (`a + iconst0`) while the rewrite's pattern
// puts it on the left. Matching is therefore expected to try both operand orders.
689 TEST(pattern, recurrent_graph_rewrite)
692 pass::Manager pass_manager;
693 pass_manager.register_pass<TestRecurrentGraphRewrite>();
696 auto a = make_shared<op::Parameter>(element::i32, shape);
697 auto iconst0 = construct_constant_node(0);
698 auto add_a1 = a + iconst0;
699 auto add_a2 = add_a1 + iconst0;
700 auto add_a3 = add_a2 + iconst0;
701 auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
703 auto b = make_shared<op::Parameter>(element::i32, shape);
704 auto add_b1 = b + iconst0;
705 auto add_b2 = add_b1 + iconst0;
706 auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
708 auto graph = abs_add_a3 * abs_add_b2;
710 auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
711 pass_manager.run_passes(f);
// After the pass, each Abs is expected to read its Parameter (a / b) directly,
// since the whole Add chain is replaced by its seed.
713 auto left_abs = graph->input_value(0).get_node_shared_ptr();
714 auto add_a = left_abs->input_value(0).get_node_shared_ptr();
717 auto right_abs = graph->input_value(1).get_node_shared_ptr();
718 auto add_b = right_abs->input_value(0).get_node_shared_ptr();
// Checks a Label placed on a Skip. The Skip passes over an optional Broadcast to
// reach a zero constant, so the same pattern matches both cases:
// - `a * broadcast(0)`: bcst_label binds to the Broadcast.
// - `b * 0`: bcst_label binds directly to the constant.
723 TEST(pattern, label_on_skip)
726 auto a = make_shared<op::Parameter>(element::i32, shape);
727 auto b = make_shared<op::Parameter>(element::i32, Shape{});
728 auto iconst = ngraph::make_zero(element::i32, Shape{});
729 auto label = std::make_shared<pattern::op::Label>(iconst);
731 std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
733 auto bcst_pred = [](std::shared_ptr<Node> n) {
734 return as_type_ptr<op::Broadcast>(n) != nullptr;
737 auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
738 auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
739 auto matcher = std::make_shared<pattern::Matcher>(
740 std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");
742 auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
743 auto mul = a * const_broadcast;
744 auto mul_scalar = b * iconst;
// With a Broadcast present, the Skip consumes it and bcst_label binds to it.
745 ASSERT_TRUE(matcher->match(mul));
746 ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
747 ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
748 ASSERT_EQ(matcher->get_pattern_map()[label], a);
// Without a Broadcast, bcst_label and const_label both bind to the constant itself.
749 ASSERT_TRUE(matcher->match(mul_scalar));
750 ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
751 ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
752 ASSERT_EQ(matcher->get_pattern_map()[label], b);
// Checks Matcher::is_contained_match(). Matching only `absn` counts as contained.
// Matching `absn2` (which also pulls in `absn`) does not. Presumably this is
// because `absn` additionally feeds `result_absn`, a consumer outside the match.
// NOTE(review): `n` is presumably a TestMatcher declared on a line not shown.
755 TEST(pattern, is_contained_match)
758 auto a = make_shared<op::Parameter>(element::i32, shape);
759 auto absn = make_shared<op::Abs>(a);
// NOTE(review): `label_a` is never used.
762 auto label_a = std::make_shared<pattern::op::Label>(a);
763 auto label_abs = make_shared<op::Abs>(a);
764 ASSERT_TRUE(n.match(label_abs, absn));
// Creating this Result attaches an extra consumer to `absn`; it is kept alive on purpose.
765 auto result_absn = make_shared<op::Result>(absn);
766 ASSERT_TRUE(n.is_contained_match());
768 auto absn2 = make_shared<op::Abs>(absn);
769 auto result_absn2 = make_shared<op::Result>(absn2);
770 auto label_abs2 = make_shared<op::Abs>(label_abs);
771 ASSERT_TRUE(n.match(label_abs2, absn2));
772 ASSERT_FALSE(n.is_contained_match());
// Checks pattern::wrap_type<T>():
// - it matches nodes by op type;
// - it records itself (and any nested wrap_type inputs) in the pattern map;
// - for v1::Multiply it matches whichever operand holds the Constant (mul1 and
//   mul2 swap the Constant's position).
775 TEST(pattern, wrap_type)
777 auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
778 auto b = make_shared<op::Abs>(a);
779 auto c = make_shared<op::Relu>(a);
780 auto mul1 = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
781 auto mul2 = make_shared<op::v1::Multiply>(op::Constant::create(element::f32, Shape{}, {1}), a);
// wrap_type<Abs>() with no inputs: matches Abs only, not Relu.
784 auto m = pattern::wrap_type<op::Abs>();
785 auto matcher = std::make_shared<pattern::Matcher>(m, "AbsMatcher");
786 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
787 ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
788 ASSERT_EQ(matcher->get_matched_nodes()[0], b);
789 ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
790 ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
// Nested wrap_type: Abs whose input is a Parameter; both pattern nodes get bound.
793 auto m1 = pattern::wrap_type<op::Parameter>();
794 auto m2 = pattern::wrap_type<op::Abs>({m1});
795 auto matcher = std::make_shared<pattern::Matcher>(m2, "ParamAbsMatcher");
796 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
797 ASSERT_EQ(matcher->get_matched_nodes().size(), 2);
798 ASSERT_EQ(matcher->get_pattern_map().count(m1), 1);
799 ASSERT_EQ(matcher->get_pattern_map().count(m2), 1);
800 ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
// Multiply(any, Constant) matches both mul1 (Constant second) and mul2 (Constant first).
803 auto m1 = pattern::wrap_type<op::v1::Multiply>(
804 {pattern::any_input(), pattern::wrap_type<op::Constant>()});
805 auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
806 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
807 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
// Same with the pattern's operand order reversed.
810 auto m1 = pattern::wrap_type<op::v1::Multiply>(
811 {pattern::wrap_type<op::Constant>(), pattern::any_input()});
812 auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
813 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
814 ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));