Remove obsoleted Min, Max operators (#2832)
[platform/upstream/dldt.git] / ngraph / core / src / pass / constant_folding_arithmetic_reduction.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "constant_folding.hpp"
18 #include "ngraph/log.hpp"
19 #include "ngraph/op/constant.hpp"
20 #include "ngraph/op/max.hpp"
21 #include "ngraph/op/min.hpp"
22 #include "ngraph/op/reduce_mean.hpp"
23 #include "ngraph/op/reduce_prod.hpp"
24 #include "ngraph/op/reduce_sum.hpp"
25 #include "ngraph/op/sum.hpp"
26 #include "ngraph/runtime/reference/max.hpp"
27 #include "ngraph/runtime/reference/mean.hpp"
28 #include "ngraph/runtime/reference/min.hpp"
29 #include "ngraph/runtime/reference/product.hpp"
30 #include "ngraph/runtime/reference/sum.hpp"
31
32 NGRAPH_SUPPRESS_DEPRECATED_START
33
34 using namespace std;
35 using namespace ngraph;
36
37 template <typename T>
38 static shared_ptr<op::Constant>
39     fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
40                                               shared_ptr<Node> reduction_node)
41 {
42     const Shape& out_shape = reduction_node->get_shape();
43     runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
44     T* data_ptr = buffer.get_ptr<T>();
45
46     if (auto reduce_max = as_type_ptr<op::v1::ReduceMax>(reduction_node))
47     {
48         runtime::reference::max<T>(constant->get_data_ptr<T>(),
49                                    data_ptr,
50                                    constant->get_output_shape(0),
51                                    reduce_max->get_reduction_axes(),
52                                    reduce_max->get_keep_dims());
53     }
54     else if (auto reduce_min = as_type_ptr<op::v1::ReduceMin>(reduction_node))
55     {
56         runtime::reference::min<T>(constant->get_data_ptr<T>(),
57                                    data_ptr,
58                                    constant->get_output_shape(0),
59                                    reduce_min->get_reduction_axes());
60     }
61     else if (auto reduce_prod = as_type_ptr<op::v1::ReduceProd>(reduction_node))
62     {
63         runtime::reference::product<T>(constant->get_data_ptr<T>(),
64                                        data_ptr,
65                                        constant->get_output_shape(0),
66                                        reduce_prod->get_reduction_axes(),
67                                        reduce_prod->get_keep_dims());
68     }
69     else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
70     {
71         runtime::reference::sum<T>(constant->get_data_ptr<T>(),
72                                    data_ptr,
73                                    constant->get_output_shape(0),
74                                    sum->get_reduction_axes(),
75                                    false);
76     }
77     else if (auto reduce_sum = as_type_ptr<op::v1::ReduceSum>(reduction_node))
78     {
79         runtime::reference::sum<T>(constant->get_data_ptr<T>(),
80                                    data_ptr,
81                                    constant->get_output_shape(0),
82                                    reduce_sum->get_reduction_axes(),
83                                    reduce_sum->get_keep_dims());
84     }
85     else if (auto reduce_mean = as_type_ptr<op::v1::ReduceMean>(reduction_node))
86     {
87         runtime::reference::mean<T>(constant->get_data_ptr<T>(),
88                                     data_ptr,
89                                     constant->get_output_shape(0),
90                                     reduce_mean->get_reduction_axes(),
91                                     reduce_mean->get_keep_dims());
92     }
93     else
94     {
95         NGRAPH_CHECK(false,
96                      "Internal nGraph error: Ops handled in "
97                      "fold_constant_arithmetic_reduction_helper must be consistent with those "
98                      "matched in construct_constant_arithmetic_reduction");
99     }
100
101     return make_shared<op::Constant>(
102         reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
103 }
104
105 static shared_ptr<op::Constant>
106     fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
107                                        shared_ptr<Node> reduction_node)
108 {
109     auto& input_element_type = constant->get_output_element_type(0);
110
111     switch (input_element_type)
112     {
113     case element::Type_t::undefined:
114         NGRAPH_CHECK(false,
115                      "Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
116         break;
117     case element::Type_t::dynamic:
118         NGRAPH_CHECK(false,
119                      "Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
120         break;
121     case element::Type_t::u1:
122         NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_arithmetic_reduction");
123         break;
124     case element::Type_t::boolean:
125         return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
126     case element::Type_t::bf16:
127         return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
128     case element::Type_t::f16:
129         return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
130     case element::Type_t::f32:
131         return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
132     case element::Type_t::f64:
133         return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
134     case element::Type_t::i8:
135         return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
136     case element::Type_t::i16:
137         return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
138     case element::Type_t::i32:
139         return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
140     case element::Type_t::i64:
141         return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
142     case element::Type_t::u8:
143         return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
144     case element::Type_t::u16:
145         return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
146     case element::Type_t::u32:
147         return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
148     case element::Type_t::u64:
149         return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
150     }
151
152     NGRAPH_UNREACHABLE("Unexpected switch case");
153 }
154
155 void pass::ConstantFolding::construct_constant_arithmetic_reduction()
156 {
157     auto constant_data_label = make_shared<pattern::op::Label>(
158         element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
159     auto constant_axes_label =
160         make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
161     auto is_supported_reduction = [](std::shared_ptr<Node> n) {
162         return (pattern::has_class<op::Sum>()(n) || pattern::has_class<op::v1::ReduceMax>()(n) ||
163                 pattern::has_class<op::v1::ReduceMin>()(n) ||
164                 pattern::has_class<op::v1::ReduceProd>()(n) ||
165                 pattern::has_class<op::v1::ReduceSum>()(n) ||
166                 pattern::has_class<op::v1::ReduceMean>()(n));
167     };
168     auto reduction =
169         std::make_shared<pattern::op::Any>(element::i32,
170                                            Shape{2},
171                                            is_supported_reduction,
172                                            NodeVector{constant_data_label, constant_axes_label});
173
174     auto constant_arithmetic_reduction_callback = [this, constant_data_label](pattern::Matcher& m) {
175         NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
176                      << m.get_match_root()->get_name();
177
178         auto pattern_map = m.get_pattern_map();
179
180         auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
181         auto reduction_match = m.get_match_root();
182
183         if (cf_is_disabled(reduction_match))
184             return false;
185
186         NGRAPH_CHECK(revalidate_and_ensure_static(reduction_match));
187
188         auto const_node = fold_constant_arithmetic_reduction(constant_match, reduction_match);
189         const_node->set_friendly_name(reduction_match->get_friendly_name());
190         replace_node(reduction_match, const_node);
191         copy_runtime_info_to_target_inputs(reduction_match, const_node);
192
193         return true;
194     };
195
196     auto arithmetic_reduction_matcher =
197         make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
198     NGRAPH_SUPPRESS_DEPRECATED_START
199     this->add_matcher(arithmetic_reduction_matcher,
200                       constant_arithmetic_reduction_callback,
201                       PassProperty::CHANGE_DYNAMIC_STATE);
202     NGRAPH_SUPPRESS_DEPRECATED_END
203 }