8030710211ae51bbf90aa09b3d0acd168e5ce8a1
[platform/upstream/dldt.git] / ngraph / core / src / pass / constant_folding_arithmetic_reduction.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "constant_folding.hpp"
18 #include "ngraph/log.hpp"
19 #include "ngraph/op/constant.hpp"
20 #include "ngraph/op/max.hpp"
21 #include "ngraph/op/min.hpp"
22 #include "ngraph/op/reduce_mean.hpp"
23 #include "ngraph/op/reduce_prod.hpp"
24 #include "ngraph/op/reduce_sum.hpp"
25 #include "ngraph/op/sum.hpp"
26 #include "ngraph/runtime/reference/max.hpp"
27 #include "ngraph/runtime/reference/mean.hpp"
28 #include "ngraph/runtime/reference/min.hpp"
29 #include "ngraph/runtime/reference/product.hpp"
30 #include "ngraph/runtime/reference/sum.hpp"
31
32 NGRAPH_SUPPRESS_DEPRECATED_START
33
34 using namespace std;
35 using namespace ngraph;
36
37 template <typename T>
38 static shared_ptr<op::Constant>
39     fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
40                                               shared_ptr<Node> reduction_node)
41 {
42     const Shape& out_shape = reduction_node->get_shape();
43     runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
44     T* data_ptr = buffer.get_ptr<T>();
45
46     if (auto max = as_type_ptr<op::Max>(reduction_node))
47     {
48         runtime::reference::max<T>(constant->get_data_ptr<T>(),
49                                    data_ptr,
50                                    constant->get_output_shape(0),
51                                    max->get_reduction_axes(),
52                                    false);
53     }
54     else if (auto reduce_max = as_type_ptr<op::v1::ReduceMax>(reduction_node))
55     {
56         runtime::reference::max<T>(constant->get_data_ptr<T>(),
57                                    data_ptr,
58                                    constant->get_output_shape(0),
59                                    reduce_max->get_reduction_axes(),
60                                    reduce_max->get_keep_dims());
61     }
62     else if (auto min = as_type_ptr<op::Min>(reduction_node))
63     {
64         runtime::reference::min<T>(constant->get_data_ptr<T>(),
65                                    data_ptr,
66                                    constant->get_output_shape(0),
67                                    min->get_reduction_axes());
68     }
69     else if (auto reduce_min = as_type_ptr<op::v1::ReduceMin>(reduction_node))
70     {
71         runtime::reference::min<T>(constant->get_data_ptr<T>(),
72                                    data_ptr,
73                                    constant->get_output_shape(0),
74                                    reduce_min->get_reduction_axes());
75     }
76     else if (auto reduce_prod = as_type_ptr<op::v1::ReduceProd>(reduction_node))
77     {
78         runtime::reference::product<T>(constant->get_data_ptr<T>(),
79                                        data_ptr,
80                                        constant->get_output_shape(0),
81                                        reduce_prod->get_reduction_axes(),
82                                        reduce_prod->get_keep_dims());
83     }
84     else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
85     {
86         runtime::reference::sum<T>(constant->get_data_ptr<T>(),
87                                    data_ptr,
88                                    constant->get_output_shape(0),
89                                    sum->get_reduction_axes(),
90                                    false);
91     }
92     else if (auto reduce_sum = as_type_ptr<op::v1::ReduceSum>(reduction_node))
93     {
94         runtime::reference::sum<T>(constant->get_data_ptr<T>(),
95                                    data_ptr,
96                                    constant->get_output_shape(0),
97                                    reduce_sum->get_reduction_axes(),
98                                    reduce_sum->get_keep_dims());
99     }
100     else if (auto reduce_mean = as_type_ptr<op::v1::ReduceMean>(reduction_node))
101     {
102         runtime::reference::mean<T>(constant->get_data_ptr<T>(),
103                                     data_ptr,
104                                     constant->get_output_shape(0),
105                                     reduce_mean->get_reduction_axes(),
106                                     reduce_mean->get_keep_dims());
107     }
108     else
109     {
110         NGRAPH_CHECK(false,
111                      "Internal nGraph error: Ops handled in "
112                      "fold_constant_arithmetic_reduction_helper must be consistent with those "
113                      "matched in construct_constant_arithmetic_reduction");
114     }
115
116     return make_shared<op::Constant>(
117         reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
118 }
119
120 static shared_ptr<op::Constant>
121     fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
122                                        shared_ptr<Node> reduction_node)
123 {
124     auto& input_element_type = constant->get_output_element_type(0);
125
126     switch (input_element_type)
127     {
128     case element::Type_t::undefined:
129         NGRAPH_CHECK(false,
130                      "Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
131         break;
132     case element::Type_t::dynamic:
133         NGRAPH_CHECK(false,
134                      "Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
135         break;
136     case element::Type_t::u1:
137         NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_arithmetic_reduction");
138         break;
139     case element::Type_t::boolean:
140         return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
141     case element::Type_t::bf16:
142         return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
143     case element::Type_t::f16:
144         return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
145     case element::Type_t::f32:
146         return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
147     case element::Type_t::f64:
148         return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
149     case element::Type_t::i8:
150         return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
151     case element::Type_t::i16:
152         return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
153     case element::Type_t::i32:
154         return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
155     case element::Type_t::i64:
156         return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
157     case element::Type_t::u8:
158         return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
159     case element::Type_t::u16:
160         return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
161     case element::Type_t::u32:
162         return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
163     case element::Type_t::u64:
164         return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
165     }
166
167     NGRAPH_UNREACHABLE("Unexpected switch case");
168 }
169
170 void pass::ConstantFolding::construct_constant_arithmetic_reduction()
171 {
172     auto constant_data_label = make_shared<pattern::op::Label>(
173         element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
174     auto constant_axes_label =
175         make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
176     auto is_supported_reduction = [](std::shared_ptr<Node> n) {
177         return (pattern::has_class<op::Max>()(n) || pattern::has_class<op::Min>()(n) ||
178                 pattern::has_class<op::Sum>()(n) || pattern::has_class<op::v1::ReduceMax>()(n) ||
179                 pattern::has_class<op::v1::ReduceMin>()(n) ||
180                 pattern::has_class<op::v1::ReduceProd>()(n) ||
181                 pattern::has_class<op::v1::ReduceSum>()(n) ||
182                 pattern::has_class<op::v1::ReduceMean>()(n));
183     };
184     auto reduction =
185         std::make_shared<pattern::op::Any>(element::i32,
186                                            Shape{2},
187                                            is_supported_reduction,
188                                            NodeVector{constant_data_label, constant_axes_label});
189
190     auto constant_arithmetic_reduction_callback = [this, constant_data_label](pattern::Matcher& m) {
191         NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
192                      << m.get_match_root()->get_name();
193
194         auto pattern_map = m.get_pattern_map();
195
196         auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
197         auto reduction_match = m.get_match_root();
198
199         if (cf_is_disabled(reduction_match))
200             return false;
201
202         NGRAPH_CHECK(revalidate_and_ensure_static(reduction_match));
203
204         auto const_node = fold_constant_arithmetic_reduction(constant_match, reduction_match);
205         const_node->set_friendly_name(reduction_match->get_friendly_name());
206         replace_node(reduction_match, const_node);
207         copy_runtime_info_to_target_inputs(reduction_match, const_node);
208
209         return true;
210     };
211
212     auto arithmetic_reduction_matcher =
213         make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
214     NGRAPH_SUPPRESS_DEPRECATED_START
215     this->add_matcher(arithmetic_reduction_matcher,
216                       constant_arithmetic_reduction_callback,
217                       PassProperty::CHANGE_DYNAMIC_STATE);
218     NGRAPH_SUPPRESS_DEPRECATED_END
219 }