/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/Mean.h"

#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>

#include <cstdint>
26 namespace luci_interpreter
31 static void resolveAxes(const int *axes_data, int num_axes, tflite::MeanParams *params)
33 params->axis_count = num_axes;
34 for (int i = 0; i < num_axes; ++i)
36 params->axis[i] = static_cast<int16>(axes_data[i]);
38 for (int i = num_axes; i < 4; ++i)
// Returns the number of distinct axes that will be reduced.
//
// Each entry of `axes_data` may be given either as a non-negative index or as a negative
// index counted from the end (`axis + input_num_dims`). Entries that resolve to an axis
// already seen earlier in the list are counted only once; the input array is not modified.
//
// axes_data       Pointer to `num_axes` axis indices.
// num_axes        Number of entries in `axes_data`.
// input_num_dims  Rank of the tensor being reduced; every resolved axis must lie in
//                 [0, input_num_dims) (checked with assert).
static int getAxisReductionCount(const int *axes_data, int num_axes, int input_num_dims)
{
  int reduction_count = num_axes;
  for (int i = 0; i < num_axes; ++i)
  {
    const int current = axes_data[i] >= 0 ? axes_data[i] : axes_data[i] + input_num_dims;
    assert(current >= 0 && current < input_num_dims);
    for (int j = 0; j < i; ++j)
    {
      const int previous = axes_data[j] >= 0 ? axes_data[j] : axes_data[j] + input_num_dims;
      // A repeated axis contributes nothing new: drop it from the count once and stop
      // scanning, so that three or more repeats are not subtracted multiple times.
      if (current == previous)
      {
        --reduction_count;
        break;
      }
    }
  }
  return reduction_count;
}
66 static Shape getOutputShape(const Shape &input_shape, const int *axes_data, int num_axes,
69 int input_num_dims = input_shape.num_dims();
70 if (input_num_dims == 0)
77 Shape output_shape(input_num_dims);
78 for (int idx = 0; idx < input_num_dims; ++idx)
81 for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
83 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
91 output_shape.dim(idx) = 1;
95 output_shape.dim(idx) = input_shape.dim(idx);
102 int num_reduce_axes = getAxisReductionCount(axes_data, num_axes, input_num_dims);
103 Shape output_shape(input_num_dims - num_reduce_axes);
104 int num_skip_axes = 0;
105 for (int idx = 0; idx < input_num_dims; ++idx)
107 bool is_axis = false;
108 for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
110 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
119 output_shape.dim(idx - num_skip_axes) = input_shape.dim(idx);
126 Mean::Mean(const Tensor *input, const Tensor *axes, Tensor *output, const ReducerParams ¶ms)
127 : KernelWithParams<ReducerParams>({input, axes}, {output}, params)
131 void Mean::configure()
133 assert(input()->element_type() == output()->element_type());
134 assert(axes()->element_type() == DataType::S32);
135 const Shape &input_shape = input()->shape();
136 int input_num_dims = input_shape.num_dims();
138 const auto *axes_data = getTensorData<int32_t>(axes());
139 int num_axes = axes()->shape().num_elements();
140 assert(num_axes <= 4);
142 Shape output_shape = getOutputShape(input_shape, axes_data, num_axes, _params.keep_dims);
143 output()->resize(output_shape);
145 tflite::MeanParams params{};
146 resolveAxes(axes_data, num_axes, ¶ms);
147 const bool need_temporaries =
148 !(_params.keep_dims && input_num_dims == 4 && params.axis_count == 2 &&
149 ((params.axis[0] == 1 && params.axis[1] == 2) ||
150 (params.axis[0] == 2 && params.axis[1] == 1)));
151 if (need_temporaries)
154 std::make_unique<Tensor>(DataType::S32, Shape(input_num_dims), AffineQuantization{}, "");
156 std::make_unique<Tensor>(DataType::S32, Shape(num_axes), AffineQuantization{}, "");
157 _temp_sum = std::make_unique<Tensor>(input()->element_type(), output()->shape(),
158 AffineQuantization{}, "");
162 void Mean::execute() const
164 switch (input()->element_type())
166 case DataType::FLOAT32:
173 throw std::runtime_error("Unsupported type.");
177 void Mean::evalFloat() const
179 const Shape &input_shape = input()->shape();
180 int input_num_dims = input_shape.num_dims();
181 const auto *axes_data = getTensorData<int32_t>(axes());
182 int num_axes = axes()->shape().num_elements();
184 tflite::MeanParams params{};
185 resolveAxes(axes_data, num_axes, ¶ms);
187 // Defer to specialized implementation for 4D Mean across axes 1 & 2.
188 if (_params.keep_dims && input_num_dims == 4 && params.axis_count == 2 &&
189 ((params.axis[0] == 1 && params.axis[1] == 2) ||
190 (params.axis[0] == 2 && params.axis[1] == 1)))
192 tflite::reference_ops::Mean(params, getTensorShape(input()), getTensorData<float>(input()),
193 getTensorShape(output()), getTensorData<float>(output()));
197 tflite::reference_ops::Mean(
198 getTensorData<float>(input()), getTensorShape(input()).DimsData(),
199 input()->shape().num_dims(), getTensorData<float>(output()),
200 getTensorShape(output()).DimsData(), output()->shape().num_dims(), axes_data, num_axes,
201 _params.keep_dims, getTensorData<int>(_temp_index.get()),
202 getTensorData<int>(_resolved_axes.get()), getTensorData<float>(_temp_sum.get()));
206 void Mean::evalQuantized() const
208 const Shape &input_shape = input()->shape();
209 int input_num_dims = input_shape.num_dims();
210 const auto *axes_data = getTensorData<int32_t>(axes());
211 int num_axes = axes()->shape().num_elements();
213 tflite::MeanParams params{};
214 resolveAxes(axes_data, num_axes, ¶ms);
216 // Defer to specialized implementation for 4D Mean across axes 1 & 2.
217 if (_params.keep_dims && input_num_dims == 4 && params.axis_count == 2 &&
218 ((params.axis[0] == 1 && params.axis[1] == 2) ||
219 (params.axis[0] == 2 && params.axis[1] == 1)))
221 tflite::reference_ops::Mean(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
222 input()->zero_point(), input()->scale(), getTensorShape(output()),
223 getTensorData<uint8_t>(output()), output()->zero_point(),
226 else if (input()->zero_point() == output()->zero_point() && input()->scale() == output()->scale())
228 tflite::reference_ops::Mean(
229 getTensorData<uint8_t>(input()), getTensorShape(input()).DimsData(),
230 input()->shape().num_dims(), getTensorData<uint8_t>(output()),
231 getTensorShape(output()).DimsData(), output()->shape().num_dims(), axes_data, num_axes,
232 _params.keep_dims, getTensorData<int>(_temp_index.get()),
233 getTensorData<int>(_resolved_axes.get()), getTensorData<int>(_temp_sum.get()));
237 tflite::reference_ops::QuantizedMeanOrSum<>(
238 getTensorData<uint8_t>(input()), input()->zero_point(), input()->scale(),
239 getTensorShape(input()).DimsData(), input()->shape().num_dims(),
240 getTensorData<uint8_t>(output()), output()->zero_point(), output()->scale(),
241 getTensorShape(output()).DimsData(), output()->shape().num_dims(), axes_data, num_axes,
242 _params.keep_dims, getTensorData<int>(_temp_index.get()),
243 getTensorData<int>(_resolved_axes.get()), getTensorData<int>(_temp_sum.get()),
244 /*compute_sum=*/false);
248 } // namespace kernels
249 } // namespace luci_interpreter