1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "dyn_elimination.hpp"
20 #include "ngraph/op/broadcast.hpp"
21 #include "ngraph/op/range.hpp"
22 #include "ngraph/op/replace_slice.hpp"
23 #include "ngraph/op/reshape.hpp"
24 #include "ngraph/op/reverse.hpp"
25 #include "ngraph/op/slice.hpp"
26 #include "ngraph/op/transpose.hpp"
27 #include "ngraph/pattern/matcher.hpp"
28 #include "ngraph/pattern/op/label.hpp"
29 #include "ngraph/runtime/reference/range.hpp"
30 #include "ngraph/slice_plan.hpp"
33 using namespace ngraph;
35 pass::DynElimination::DynElimination()
38 construct_transpose();
42 void pass::DynElimination::construct_transpose()
44 auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
46 make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
48 auto transpose = make_shared<op::Transpose>(data_arg_label, perm_arg_label);
50 auto transpose_callback = [data_arg_label, perm_arg_label](pattern::Matcher& m) {
51 auto pattern_map = m.get_pattern_map();
53 auto data_arg = pattern_map[data_arg_label];
54 auto perm_arg = static_pointer_cast<op::Constant>(pattern_map[perm_arg_label]);
56 // TODO(amprocte): Can't handle the case where data shape is dynamic, because static
57 // Reshape requries the exact output shape to be declared. See if we can come up with a
59 if (data_arg->get_output_partial_shape(0).is_dynamic())
64 auto& data_shape = data_arg->get_output_shape(0);
66 NGRAPH_CHECK(perm_arg->get_output_partial_shape(0).rank().compatible(1));
67 NGRAPH_CHECK(perm_arg->get_output_element_type(0).compatible(element::i64));
69 if (perm_arg->get_output_element_type(0).is_dynamic() ||
70 perm_arg->get_output_partial_shape(0).is_dynamic())
75 auto perm = perm_arg->get_axis_vector_val();
77 auto output_shape = ngraph::apply_permutation(data_shape, perm);
79 auto replacement = std::make_shared<op::Reshape>(data_arg, perm, output_shape);
81 replace_node(m.get_match_root(), replacement);
85 auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
86 NGRAPH_SUPPRESS_DEPRECATED_START
87 add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
88 NGRAPH_SUPPRESS_DEPRECATED_END
92 std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
94 const std::shared_ptr<op::Constant>& start_arg,
95 const std::shared_ptr<op::Constant>& step_arg)
97 std::vector<T> elements(shape_size(shape));
98 std::vector<T> start_vec = start_arg->get_vector<T>();
99 std::vector<T> step_vec = step_arg->get_vector<T>();
101 NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
103 runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape, elements.data());
105 return make_shared<op::Constant>(et, shape, elements);
108 void pass::DynElimination::construct_range()
110 auto start_arg_label =
111 make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
112 auto stop_arg_label =
113 make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
114 auto step_arg_label =
115 make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
117 auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
119 auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
120 auto pattern_map = m.get_pattern_map();
122 auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
123 auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
124 auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
126 NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
127 step_arg->get_output_partial_shape(0).rank().compatible(0));
129 auto et = range_node->get_output_element_type(0);
130 auto shape = range_node->get_output_shape(0);
132 std::shared_ptr<op::Constant> replacement;
134 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
135 #pragma GCC diagnostic push
136 #pragma GCC diagnostic error "-Wswitch"
137 #pragma GCC diagnostic error "-Wswitch-enum"
141 case element::Type_t::bf16:
142 replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
144 case element::Type_t::f16:
145 replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
147 case element::Type_t::f32:
148 replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
150 case element::Type_t::f64:
151 replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
153 case element::Type_t::i8:
154 replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
156 case element::Type_t::i16:
157 replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
159 case element::Type_t::i32:
160 replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
162 case element::Type_t::i64:
163 replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
165 case element::Type_t::u8:
166 replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
168 case element::Type_t::u16:
169 replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
171 case element::Type_t::u32:
172 replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
174 case element::Type_t::u64:
175 replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
177 case element::Type_t::u1:
178 case element::Type_t::undefined:
179 case element::Type_t::dynamic:
180 case element::Type_t::boolean:
181 NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
184 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
185 #pragma GCC diagnostic pop
188 replace_node(range_node, replacement);
192 auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
193 NGRAPH_SUPPRESS_DEPRECATED_START
194 add_matcher(range_matcher, range_callback, all_pass_property_off);
195 NGRAPH_SUPPRESS_DEPRECATED_END