1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "dyn_elimination.hpp"
20 #include "ngraph/op/broadcast.hpp"
21 #include "ngraph/op/range.hpp"
22 #include "ngraph/op/replace_slice.hpp"
23 #include "ngraph/op/reshape.hpp"
24 #include "ngraph/op/slice.hpp"
25 #include "ngraph/op/transpose.hpp"
26 #include "ngraph/pattern/matcher.hpp"
27 #include "ngraph/pattern/op/label.hpp"
28 #include "ngraph/runtime/reference/range.hpp"
29 #include "ngraph/slice_plan.hpp"
31 NGRAPH_SUPPRESS_DEPRECATED_START
34 using namespace ngraph;
// Constructor of the DynElimination pass. The visible body calls construct_transpose(),
// which builds and registers the Transpose rewrite defined in this file.
// NOTE(review): the constructor's braces (and possibly an initializer list or further
// construct_*() calls) are not visible in this excerpt -- confirm against the full file.
36 pass::DynElimination::DynElimination()
39 construct_transpose();
// Builds and registers the matcher that rewrites a Transpose whose permutation input is a
// Constant into a static Reshape:
//   Transpose(data, Constant perm)  ->  Reshape(data, perm, apply_permutation(data_shape, perm))
// NOTE(review): several original lines (braces, the bodies of the two `if` guards, the
// callback's return statements) are not visible in this excerpt.
43 void pass::DynElimination::construct_transpose()
// Pattern label for the data input; declared with element::f32 and Shape{1, 2, 3}.
// NOTE(review): whether these values constrain what the label matches is not shown here.
45 auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
// Label for the permutation input, restricted by predicate to op::Constant nodes.
// (The `auto perm_arg_label =` left-hand side of this statement is not visible here.)
47 make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
// The pattern to match: a Transpose fed by the two labels above.
49 auto transpose = make_shared<op::Transpose>(data_arg_label, perm_arg_label);
// Callback run on each match; the labels are captured so the matched nodes can be looked up.
51 auto transpose_callback = [data_arg_label, perm_arg_label](pattern::Matcher& m) {
52 auto pattern_map = m.get_pattern_map();
54 auto data_arg = pattern_map[data_arg_label];
// The label predicate only admits op::Constant, so the unchecked static cast is used here.
55 auto perm_arg = static_pointer_cast<op::Constant>(pattern_map[perm_arg_label]);
57 // TODO(amprocte): Can't handle the case where data shape is dynamic, because static
58 // Reshape requires the exact output shape to be declared. See if we can come up with a
// Guard for a dynamic data shape; per the TODO above this case is not handled.
// NOTE(review): the guard's body is not visible -- presumably it exits without rewriting.
60 if (data_arg->get_output_partial_shape(0).is_dynamic())
65 auto& data_shape = data_arg->get_output_shape(0);
// Sanity checks: the permutation's rank must be compatible with 1 and its element type
// compatible with i64.
67 NGRAPH_CHECK(perm_arg->get_output_partial_shape(0).rank().compatible(1));
68 NGRAPH_CHECK(perm_arg->get_output_element_type(0).compatible(element::i64));
// Guard for a permutation whose element type or shape is dynamic.
// NOTE(review): the guard's body is not visible -- presumably it exits without rewriting.
70 if (perm_arg->get_output_element_type(0).is_dynamic() ||
71 perm_arg->get_output_partial_shape(0).is_dynamic())
// Read the constant permutation as an axis vector.
76 auto perm = perm_arg->get_axis_vector_val();
// Output shape of the replacement: the data shape with the permutation applied.
78 auto output_shape = ngraph::apply_permutation(data_shape, perm);
// Static Reshape taking the permutation as its axis-order argument and the permuted shape
// as its declared output shape.
80 auto replacement = std::make_shared<op::Reshape>(data_arg, perm, output_shape);
// Swap the matched Transpose (the match root) for the Reshape.
82 replace_node(m.get_match_root(), replacement);
// Wrap the pattern in a named matcher and register it together with the callback.
86 auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
// The add_matcher call is bracketed by the deprecation-suppression macros.
87 NGRAPH_SUPPRESS_DEPRECATED_START
88 add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
89 NGRAPH_SUPPRESS_DEPRECATED_END
// Builds the Constant that replaces a Range node.
//
// et        - element type given to the resulting Constant.
// shape     - shape of the resulting Constant; shape_size(shape) elements are generated.
// start_arg - Constant holding the range's start value (must contain exactly one element).
// step_arg  - Constant holding the range's step value (must contain exactly one element).
//
// Returns a new op::Constant of type `et` and shape `shape` whose contents are produced by
// runtime::reference::range<T> from the start and step values.
// NOTE(review): `T` is used as the C++ value type, but the `template` header and the
// declaration of the `shape` parameter are not visible in this excerpt.
93 std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
95 const std::shared_ptr<op::Constant>& start_arg,
96 const std::shared_ptr<op::Constant>& step_arg)
// Output buffer sized to the number of elements implied by `shape`.
98 std::vector<T> elements(shape_size(shape));
// Read the start and step constants as vectors of T.
99 std::vector<T> start_vec = start_arg->get_vector<T>();
100 std::vector<T> step_vec = step_arg->get_vector<T>();
// Both inputs must hold exactly one value each.
102 NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
// Fill `elements` using the reference range implementation.
104 runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape, elements.data());
106 return make_shared<op::Constant>(et, shape, elements);
// Builds and registers the matcher that replaces a Range whose start, stop and step inputs
// are all Constants with a single precomputed Constant.
// NOTE(review): several original lines (braces, the `switch` header, the `break` statements,
// the `#endif` lines, the callback's return) are not visible in this excerpt.
109 void pass::DynElimination::construct_range()
// Labels for the three Range inputs; each is restricted by predicate to op::Constant nodes
// and declared with element::f32 and an empty Shape.
// NOTE(review): whether the declared type/shape constrain matching is not shown here.
111 auto start_arg_label =
112 make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
113 auto stop_arg_label =
114 make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
115 auto step_arg_label =
116 make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
// The pattern to match: a Range fed by the three labels above.
118 auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
// Callback run on each match. stop_arg_label is captured, but in the visible lines only the
// start and step constants are read; the output extent comes from the Range node's own
// output shape.
120 auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
121 auto pattern_map = m.get_pattern_map();
// The label predicates only admit op::Constant, so unchecked static casts are used here.
123 auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
124 auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
125 auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
// Sanity check: start and step must have ranks compatible with 0.
127 NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
128 step_arg->get_output_partial_shape(0).rank().compatible(0));
// Element type and shape of the replacement are taken from the Range node's output.
130 auto et = range_node->get_output_element_type(0);
131 auto shape = range_node->get_output_shape(0);
133 std::shared_ptr<op::Constant> replacement;
// On GCC (except 4.8), promote -Wswitch and -Wswitch-enum to errors for the code that
// follows, so an enumerator left out of the case list below fails the build.
135 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
136 #pragma GCC diagnostic push
137 #pragma GCC diagnostic error "-Wswitch"
138 #pragma GCC diagnostic error "-Wswitch-enum"
// Each supported element type maps to make_range_replacement instantiated with the
// corresponding C++ type.
142 case element::Type_t::bf16:
143 replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
145 case element::Type_t::f16:
146 replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
148 case element::Type_t::f32:
149 replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
151 case element::Type_t::f64:
152 replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
154 case element::Type_t::i8:
155 replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
157 case element::Type_t::i16:
158 replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
160 case element::Type_t::i32:
161 replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
163 case element::Type_t::i64:
164 replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
166 case element::Type_t::u8:
167 replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
169 case element::Type_t::u16:
170 replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
172 case element::Type_t::u32:
173 replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
175 case element::Type_t::u64:
176 replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
// Element types with no range replacement: these share one failing check.
178 case element::Type_t::u1:
179 case element::Type_t::undefined:
180 case element::Type_t::dynamic:
181 case element::Type_t::boolean:
182 NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
// Restore the diagnostic state saved by the push above.
185 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
186 #pragma GCC diagnostic pop
// Swap the matched Range node for the precomputed Constant.
189 replace_node(range_node, replacement);
// Wrap the pattern in a named matcher and register it together with the callback.
193 auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
// The add_matcher call is bracketed by the deprecation-suppression macros.
194 NGRAPH_SUPPRESS_DEPRECATED_START
195 add_matcher(range_matcher, range_callback, all_pass_property_off);
196 NGRAPH_SUPPRESS_DEPRECATED_END