1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
#include "constant_folding.hpp"

#include <memory>

#include "ngraph/log.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/reduce_mean.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/mean.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/sum.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace std;
using namespace ngraph;
38 static shared_ptr<op::Constant>
39 fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
40 shared_ptr<Node> reduction_node)
42 const Shape& out_shape = reduction_node->get_shape();
43 runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
44 T* data_ptr = buffer.get_ptr<T>();
46 if (auto max = as_type_ptr<op::Max>(reduction_node))
48 runtime::reference::max<T>(constant->get_data_ptr<T>(),
50 constant->get_output_shape(0),
51 max->get_reduction_axes(),
54 else if (auto reduce_max = as_type_ptr<op::v1::ReduceMax>(reduction_node))
56 runtime::reference::max<T>(constant->get_data_ptr<T>(),
58 constant->get_output_shape(0),
59 reduce_max->get_reduction_axes(),
60 reduce_max->get_keep_dims());
62 else if (auto min = as_type_ptr<op::Min>(reduction_node))
64 runtime::reference::min<T>(constant->get_data_ptr<T>(),
66 constant->get_output_shape(0),
67 min->get_reduction_axes());
69 else if (auto reduce_min = as_type_ptr<op::v1::ReduceMin>(reduction_node))
71 runtime::reference::min<T>(constant->get_data_ptr<T>(),
73 constant->get_output_shape(0),
74 reduce_min->get_reduction_axes());
76 else if (auto reduce_prod = as_type_ptr<op::v1::ReduceProd>(reduction_node))
78 runtime::reference::product<T>(constant->get_data_ptr<T>(),
80 constant->get_output_shape(0),
81 reduce_prod->get_reduction_axes(),
82 reduce_prod->get_keep_dims());
84 else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
86 runtime::reference::sum<T>(constant->get_data_ptr<T>(),
88 constant->get_output_shape(0),
89 sum->get_reduction_axes(),
92 else if (auto reduce_sum = as_type_ptr<op::v1::ReduceSum>(reduction_node))
94 runtime::reference::sum<T>(constant->get_data_ptr<T>(),
96 constant->get_output_shape(0),
97 reduce_sum->get_reduction_axes(),
98 reduce_sum->get_keep_dims());
100 else if (auto reduce_mean = as_type_ptr<op::v1::ReduceMean>(reduction_node))
102 runtime::reference::mean<T>(constant->get_data_ptr<T>(),
104 constant->get_output_shape(0),
105 reduce_mean->get_reduction_axes(),
106 reduce_mean->get_keep_dims());
111 "Internal nGraph error: Ops handled in "
112 "fold_constant_arithmetic_reduction_helper must be consistent with those "
113 "matched in construct_constant_arithmetic_reduction");
116 return make_shared<op::Constant>(
117 reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
120 static shared_ptr<op::Constant>
121 fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
122 shared_ptr<Node> reduction_node)
124 auto& input_element_type = constant->get_output_element_type(0);
126 switch (input_element_type)
128 case element::Type_t::undefined:
130 "Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
132 case element::Type_t::dynamic:
134 "Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
136 case element::Type_t::u1:
137 NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_arithmetic_reduction");
139 case element::Type_t::boolean:
140 return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
141 case element::Type_t::bf16:
142 return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
143 case element::Type_t::f16:
144 return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
145 case element::Type_t::f32:
146 return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
147 case element::Type_t::f64:
148 return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
149 case element::Type_t::i8:
150 return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
151 case element::Type_t::i16:
152 return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
153 case element::Type_t::i32:
154 return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
155 case element::Type_t::i64:
156 return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
157 case element::Type_t::u8:
158 return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
159 case element::Type_t::u16:
160 return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
161 case element::Type_t::u32:
162 return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
163 case element::Type_t::u64:
164 return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
167 NGRAPH_UNREACHABLE("Unexpected switch case");
170 void pass::ConstantFolding::construct_constant_arithmetic_reduction()
172 auto constant_data_label = make_shared<pattern::op::Label>(
173 element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
174 auto constant_axes_label =
175 make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
176 auto is_supported_reduction = [](std::shared_ptr<Node> n) {
177 return (pattern::has_class<op::Max>()(n) || pattern::has_class<op::Min>()(n) ||
178 pattern::has_class<op::Sum>()(n) || pattern::has_class<op::v1::ReduceMax>()(n) ||
179 pattern::has_class<op::v1::ReduceMin>()(n) ||
180 pattern::has_class<op::v1::ReduceProd>()(n) ||
181 pattern::has_class<op::v1::ReduceSum>()(n) ||
182 pattern::has_class<op::v1::ReduceMean>()(n));
185 std::make_shared<pattern::op::Any>(element::i32,
187 is_supported_reduction,
188 NodeVector{constant_data_label, constant_axes_label});
190 auto constant_arithmetic_reduction_callback = [this, constant_data_label](pattern::Matcher& m) {
191 NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
192 << m.get_match_root()->get_name();
194 auto pattern_map = m.get_pattern_map();
196 auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
197 auto reduction_match = m.get_match_root();
199 if (cf_is_disabled(reduction_match))
202 NGRAPH_CHECK(revalidate_and_ensure_static(reduction_match));
204 auto const_node = fold_constant_arithmetic_reduction(constant_match, reduction_match);
205 const_node->set_friendly_name(reduction_match->get_friendly_name());
206 replace_node(reduction_match, const_node);
207 copy_runtime_info_to_target_inputs(reduction_match, const_node);
212 auto arithmetic_reduction_matcher =
213 make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
214 NGRAPH_SUPPRESS_DEPRECATED_START
215 this->add_matcher(arithmetic_reduction_matcher,
216 constant_arithmetic_reduction_callback,
217 PassProperty::CHANGE_DYNAMIC_STATE);
218 NGRAPH_SUPPRESS_DEPRECATED_END