1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "constant_folding.hpp"
18 #include "ngraph/log.hpp"
19 #include "ngraph/op/constant.hpp"
20 #include "ngraph/op/max.hpp"
21 #include "ngraph/op/min.hpp"
22 #include "ngraph/op/reduce_mean.hpp"
23 #include "ngraph/op/reduce_prod.hpp"
24 #include "ngraph/op/reduce_sum.hpp"
25 #include "ngraph/op/sum.hpp"
26 #include "ngraph/runtime/reference/max.hpp"
27 #include "ngraph/runtime/reference/mean.hpp"
28 #include "ngraph/runtime/reference/min.hpp"
29 #include "ngraph/runtime/reference/product.hpp"
30 #include "ngraph/runtime/reference/sum.hpp"
32 NGRAPH_SUPPRESS_DEPRECATED_START
35 using namespace ngraph;
38 static shared_ptr<op::Constant>
39 fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
40 shared_ptr<Node> reduction_node)
42 const Shape& out_shape = reduction_node->get_shape();
43 runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
44 T* data_ptr = buffer.get_ptr<T>();
46 if (auto reduce_max = as_type_ptr<op::v1::ReduceMax>(reduction_node))
48 runtime::reference::max<T>(constant->get_data_ptr<T>(),
50 constant->get_output_shape(0),
51 reduce_max->get_reduction_axes(),
52 reduce_max->get_keep_dims());
54 else if (auto reduce_min = as_type_ptr<op::v1::ReduceMin>(reduction_node))
56 runtime::reference::min<T>(constant->get_data_ptr<T>(),
58 constant->get_output_shape(0),
59 reduce_min->get_reduction_axes());
61 else if (auto reduce_prod = as_type_ptr<op::v1::ReduceProd>(reduction_node))
63 runtime::reference::product<T>(constant->get_data_ptr<T>(),
65 constant->get_output_shape(0),
66 reduce_prod->get_reduction_axes(),
67 reduce_prod->get_keep_dims());
69 else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
71 runtime::reference::sum<T>(constant->get_data_ptr<T>(),
73 constant->get_output_shape(0),
74 sum->get_reduction_axes(),
77 else if (auto reduce_sum = as_type_ptr<op::v1::ReduceSum>(reduction_node))
79 runtime::reference::sum<T>(constant->get_data_ptr<T>(),
81 constant->get_output_shape(0),
82 reduce_sum->get_reduction_axes(),
83 reduce_sum->get_keep_dims());
85 else if (auto reduce_mean = as_type_ptr<op::v1::ReduceMean>(reduction_node))
87 runtime::reference::mean<T>(constant->get_data_ptr<T>(),
89 constant->get_output_shape(0),
90 reduce_mean->get_reduction_axes(),
91 reduce_mean->get_keep_dims());
96 "Internal nGraph error: Ops handled in "
97 "fold_constant_arithmetic_reduction_helper must be consistent with those "
98 "matched in construct_constant_arithmetic_reduction");
101 return make_shared<op::Constant>(
102 reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
105 static shared_ptr<op::Constant>
106 fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
107 shared_ptr<Node> reduction_node)
109 auto& input_element_type = constant->get_output_element_type(0);
111 switch (input_element_type)
113 case element::Type_t::undefined:
115 "Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
117 case element::Type_t::dynamic:
119 "Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
121 case element::Type_t::u1:
122 NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_arithmetic_reduction");
124 case element::Type_t::boolean:
125 return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
126 case element::Type_t::bf16:
127 return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
128 case element::Type_t::f16:
129 return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
130 case element::Type_t::f32:
131 return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
132 case element::Type_t::f64:
133 return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
134 case element::Type_t::i8:
135 return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
136 case element::Type_t::i16:
137 return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
138 case element::Type_t::i32:
139 return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
140 case element::Type_t::i64:
141 return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
142 case element::Type_t::u8:
143 return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
144 case element::Type_t::u16:
145 return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
146 case element::Type_t::u32:
147 return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
148 case element::Type_t::u64:
149 return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
152 NGRAPH_UNREACHABLE("Unexpected switch case");
155 void pass::ConstantFolding::construct_constant_arithmetic_reduction()
157 auto constant_data_label = make_shared<pattern::op::Label>(
158 element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
159 auto constant_axes_label =
160 make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
161 auto is_supported_reduction = [](std::shared_ptr<Node> n) {
162 return (pattern::has_class<op::Sum>()(n) || pattern::has_class<op::v1::ReduceMax>()(n) ||
163 pattern::has_class<op::v1::ReduceMin>()(n) ||
164 pattern::has_class<op::v1::ReduceProd>()(n) ||
165 pattern::has_class<op::v1::ReduceSum>()(n) ||
166 pattern::has_class<op::v1::ReduceMean>()(n));
169 std::make_shared<pattern::op::Any>(element::i32,
171 is_supported_reduction,
172 NodeVector{constant_data_label, constant_axes_label});
174 auto constant_arithmetic_reduction_callback = [this, constant_data_label](pattern::Matcher& m) {
175 NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
176 << m.get_match_root()->get_name();
178 auto pattern_map = m.get_pattern_map();
180 auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
181 auto reduction_match = m.get_match_root();
183 if (cf_is_disabled(reduction_match))
186 NGRAPH_CHECK(revalidate_and_ensure_static(reduction_match));
188 auto const_node = fold_constant_arithmetic_reduction(constant_match, reduction_match);
189 const_node->set_friendly_name(reduction_match->get_friendly_name());
190 replace_node(reduction_match, const_node);
191 copy_runtime_info_to_target_inputs(reduction_match, const_node);
196 auto arithmetic_reduction_matcher =
197 make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
198 NGRAPH_SUPPRESS_DEPRECATED_START
199 this->add_matcher(arithmetic_reduction_matcher,
200 constant_arithmetic_reduction_callback,
201 PassProperty::CHANGE_DYNAMIC_STATE);
202 NGRAPH_SUPPRESS_DEPRECATED_END