Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Mean.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Mean.h"
19
20 #include "kernels/Utils.h"
21
22 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
23
24 #include <stdexcept>
25
26 namespace luci_interpreter
27 {
28 namespace kernels
29 {
30
31 static void resolveAxes(const int *axes_data, int num_axes, tflite::MeanParams *params)
32 {
33   params->axis_count = num_axes;
34   for (int i = 0; i < num_axes; ++i)
35   {
36     params->axis[i] = static_cast<int16>(axes_data[i]);
37   }
38   for (int i = num_axes; i < 4; ++i)
39   {
40     params->axis[i] = 1;
41   }
42 }
43
44 // Returns the number of axes that will be reduced. Removes duplicates.
45 static int getAxisReductionCount(const int *axes_data, int num_axes, int input_num_dims)
46 {
47   int reduction_count = num_axes;
48   for (int i = 0; i < num_axes; ++i)
49   {
50     int current = axes_data[i] >= 0 ? axes_data[i] : axes_data[i] + input_num_dims;
51     assert(current >= 0 && current < input_num_dims);
52     for (int j = 0; j < i; j++)
53     {
54       int previous = axes_data[j] >= 0 ? axes_data[j] : axes_data[j] + input_num_dims;
55       // This checks for duplicate axis
56       if (current == previous)
57       {
58         --reduction_count;
59         break;
60       }
61     }
62   }
63   return reduction_count;
64 }
65
66 static Shape getOutputShape(const Shape &input_shape, const int *axes_data, int num_axes,
67                             bool keep_dims)
68 {
69   int input_num_dims = input_shape.num_dims();
70   if (input_num_dims == 0)
71   {
72     return Shape(0);
73   }
74
75   if (keep_dims)
76   {
77     Shape output_shape(input_num_dims);
78     for (int idx = 0; idx < input_num_dims; ++idx)
79     {
80       bool is_axis = false;
81       for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
82       {
83         if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
84         {
85           is_axis = true;
86           break;
87         }
88       }
89       if (is_axis)
90       {
91         output_shape.dim(idx) = 1;
92       }
93       else
94       {
95         output_shape.dim(idx) = input_shape.dim(idx);
96       }
97     }
98     return output_shape;
99   }
100   else
101   {
102     int num_reduce_axes = getAxisReductionCount(axes_data, num_axes, input_num_dims);
103     Shape output_shape(input_num_dims - num_reduce_axes);
104     int num_skip_axes = 0;
105     for (int idx = 0; idx < input_num_dims; ++idx)
106     {
107       bool is_axis = false;
108       for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
109       {
110         if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
111         {
112           ++num_skip_axes;
113           is_axis = true;
114           break;
115         }
116       }
117       if (!is_axis)
118       {
119         output_shape.dim(idx - num_skip_axes) = input_shape.dim(idx);
120       }
121     }
122     return output_shape;
123   }
124 }
125
126 Mean::Mean(const Tensor *input, const Tensor *axes, Tensor *output, const ReducerParams &params)
127     : KernelWithParams<ReducerParams>({input, axes}, {output}, params)
128 {
129 }
130
131 void Mean::configure()
132 {
133   assert(input()->element_type() == output()->element_type());
134   assert(axes()->element_type() == DataType::S32);
135   const Shape &input_shape = input()->shape();
136   int input_num_dims = input_shape.num_dims();
137
138   const auto *axes_data = getTensorData<int32_t>(axes());
139   int num_axes = axes()->shape().num_elements();
140   assert(num_axes <= 4);
141
142   Shape output_shape = getOutputShape(input_shape, axes_data, num_axes, _params.keep_dims);
143   output()->resize(output_shape);
144
145   tflite::MeanParams params{};
146   resolveAxes(axes_data, num_axes, &params);
147   const bool need_temporaries =
148       !(_params.keep_dims && input_num_dims == 4 && params.axis_count == 2 &&
149         ((params.axis[0] == 1 && params.axis[1] == 2) ||
150          (params.axis[0] == 2 && params.axis[1] == 1)));
151   if (need_temporaries)
152   {
153     _temp_index =
154         std::make_unique<Tensor>(DataType::S32, Shape(input_num_dims), AffineQuantization{}, "");
155     _resolved_axes =
156         std::make_unique<Tensor>(DataType::S32, Shape(num_axes), AffineQuantization{}, "");
157     _temp_sum = std::make_unique<Tensor>(input()->element_type(), output()->shape(),
158                                          AffineQuantization{}, "");
159   }
160 }
161
162 void Mean::execute() const
163 {
164   switch (input()->element_type())
165   {
166     case DataType::FLOAT32:
167       evalFloat();
168       break;
169     case DataType::U8:
170       evalQuantized();
171       break;
172     default:
173       throw std::runtime_error("Unsupported type.");
174   }
175 }
176
177 void Mean::evalFloat() const
178 {
179   const Shape &input_shape = input()->shape();
180   int input_num_dims = input_shape.num_dims();
181   const auto *axes_data = getTensorData<int32_t>(axes());
182   int num_axes = axes()->shape().num_elements();
183
184   tflite::MeanParams params{};
185   resolveAxes(axes_data, num_axes, &params);
186
187   // Defer to specialized implementation for 4D Mean across axes 1 & 2.
188   if (_params.keep_dims && input_num_dims == 4 && params.axis_count == 2 &&
189       ((params.axis[0] == 1 && params.axis[1] == 2) ||
190        (params.axis[0] == 2 && params.axis[1] == 1)))
191   {
192     tflite::reference_ops::Mean(params, getTensorShape(input()), getTensorData<float>(input()),
193                                 getTensorShape(output()), getTensorData<float>(output()));
194   }
195   else
196   {
197     tflite::reference_ops::Mean(
198         getTensorData<float>(input()), getTensorShape(input()).DimsData(),
199         input()->shape().num_dims(), getTensorData<float>(output()),
200         getTensorShape(output()).DimsData(), output()->shape().num_dims(), axes_data, num_axes,
201         _params.keep_dims, getTensorData<int>(_temp_index.get()),
202         getTensorData<int>(_resolved_axes.get()), getTensorData<float>(_temp_sum.get()));
203   }
204 }
205
206 void Mean::evalQuantized() const
207 {
208   const Shape &input_shape = input()->shape();
209   int input_num_dims = input_shape.num_dims();
210   const auto *axes_data = getTensorData<int32_t>(axes());
211   int num_axes = axes()->shape().num_elements();
212
213   tflite::MeanParams params{};
214   resolveAxes(axes_data, num_axes, &params);
215
216   // Defer to specialized implementation for 4D Mean across axes 1 & 2.
217   if (_params.keep_dims && input_num_dims == 4 && params.axis_count == 2 &&
218       ((params.axis[0] == 1 && params.axis[1] == 2) ||
219        (params.axis[0] == 2 && params.axis[1] == 1)))
220   {
221     tflite::reference_ops::Mean(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
222                                 input()->zero_point(), input()->scale(), getTensorShape(output()),
223                                 getTensorData<uint8_t>(output()), output()->zero_point(),
224                                 output()->scale());
225   }
226   else if (input()->zero_point() == output()->zero_point() && input()->scale() == output()->scale())
227   {
228     tflite::reference_ops::Mean(
229         getTensorData<uint8_t>(input()), getTensorShape(input()).DimsData(),
230         input()->shape().num_dims(), getTensorData<uint8_t>(output()),
231         getTensorShape(output()).DimsData(), output()->shape().num_dims(), axes_data, num_axes,
232         _params.keep_dims, getTensorData<int>(_temp_index.get()),
233         getTensorData<int>(_resolved_axes.get()), getTensorData<int>(_temp_sum.get()));
234   }
235   else
236   {
237     tflite::reference_ops::QuantizedMeanOrSum<>(
238         getTensorData<uint8_t>(input()), input()->zero_point(), input()->scale(),
239         getTensorShape(input()).DimsData(), input()->shape().num_dims(),
240         getTensorData<uint8_t>(output()), output()->zero_point(), output()->scale(),
241         getTensorShape(output()).DimsData(), output()->shape().num_dims(), axes_data, num_axes,
242         _params.keep_dims, getTensorData<int>(_temp_index.get()),
243         getTensorData<int>(_resolved_axes.get()), getTensorData<int>(_temp_sum.get()),
244         /*compute_sum=*/false);
245   }
246 }
247
248 } // namespace kernels
249 } // namespace luci_interpreter