1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"
#include "pass/opset0_downgrade.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;
30 TEST(dyn_elimination, transpose)
32 Shape shape_in{2, 4, 6, 8};
33 auto param = make_shared<op::Parameter>(element::boolean, shape_in);
36 make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
38 auto transpose = make_shared<op::Transpose>(param, constant_perm);
40 auto f = make_shared<Function>(transpose, ParameterVector{param});
42 pass::Manager pass_manager;
43 pass_manager.register_pass<pass::DynElimination>();
44 pass_manager.run_passes(f);
46 ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
47 ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
50 as_type_ptr<op::Reshape>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
51 ASSERT_TRUE(new_reshape);
53 ASSERT_EQ(new_reshape->get_input_order(), (AxisVector{2, 3, 1, 0}));
54 ASSERT_EQ(new_reshape->get_output_shape(0), (Shape{6, 8, 4, 2}));
55 ASSERT_EQ(new_reshape->get_output_element_type(0), element::boolean);
58 // For now, we can't handle the case where the input has dynamic shapes,
59 // because the classic Reshape op demands a Shape. Probably won't be able to
60 // deal with this until/unless we make a "StaticTranspose". Just make sure
61 // we don't crash or mangle the graph.
62 TEST(dyn_elimination, transpose_dyn_shape)
64 PartialShape shape_in{2, 4, Dimension::dynamic(), 8};
66 auto param = make_shared<op::Parameter>(element::boolean, shape_in);
69 make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
71 auto transpose = make_shared<op::Transpose>(param, constant_perm);
73 auto f = make_shared<Function>(transpose, ParameterVector{param});
75 pass::Manager pass_manager;
76 pass_manager.register_pass<pass::DynElimination>();
77 pass_manager.run_passes(f);
79 ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 1);
80 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
83 as_type_ptr<op::Transpose>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
84 ASSERT_TRUE(new_transpose);
86 ASSERT_EQ(new_transpose->get_output_element_type(0), element::boolean);
87 ASSERT_TRUE(new_transpose->get_output_partial_shape(0).relaxes(
88 PartialShape{Dimension::dynamic(), 8, 4, 2}));
91 TEST(dyn_elimination, range)
93 auto constant_start = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{0});
94 auto constant_stop = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{5});
95 auto constant_step = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{2});
97 auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
99 ASSERT_EQ(range->get_element_type(), element::i64);
100 ASSERT_EQ(range->get_shape(), (Shape{3}));
102 auto f = make_shared<Function>(range, ParameterVector{});
104 pass::Manager pass_manager;
105 pass_manager.register_pass<pass::DynElimination>();
106 pass_manager.run_passes(f);
108 ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
109 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
112 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
114 ASSERT_NE(replacement, nullptr);
115 ASSERT_EQ(replacement->get_element_type(), element::i64);
116 ASSERT_EQ(replacement->get_shape(), (Shape{3}));
118 auto vals = replacement->get_vector<int64_t>();
120 ASSERT_EQ(vals, (vector<int64_t>{0, 2, 4}));
123 TEST(dyn_elimination, range_f64)
125 auto constant_start = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{-0.5});
126 auto constant_stop = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{2});
127 auto constant_step = make_shared<op::Constant>(element::f64, Shape{}, vector<double>{0.25});
129 auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
131 ASSERT_EQ(range->get_element_type(), element::f64);
132 ASSERT_EQ(range->get_shape(), (Shape{10}));
134 auto f = make_shared<Function>(range, ParameterVector{});
136 pass::Manager pass_manager;
137 pass_manager.register_pass<pass::DynElimination>();
138 pass_manager.run_passes(f);
140 ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
141 ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
144 as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
146 ASSERT_NE(replacement, nullptr);
147 ASSERT_EQ(replacement->get_element_type(), element::f64);
148 ASSERT_EQ(replacement->get_shape(), (Shape{10}));
150 auto vals = replacement->get_vector<double>();
152 ASSERT_TRUE(test::all_close_f(
153 vals, vector<double>{-0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75}));