2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Sum.h"
20 #include "kernels/Utils.h"
22 #include <tensorflow/lite/kernels/internal/reference/reduce.h>
26 namespace luci_interpreter
// Returns the number of distinct axes that will be reduced; duplicates are counted once.
//
// axes_data      - pointer to num_axes axis indices. A negative index is taken relative
//                  to the end, i.e. it is resolved as axis + input_num_dims.
// num_axes       - number of entries readable through axes_data.
// input_num_dims - rank of the tensor being reduced.
//
// Every resolved axis is expected to lie in [0, input_num_dims). This is enforced only by
// assert, so it is not checked when NDEBUG is defined.
static int getAxisReductionCount(const int32_t *axes_data, int num_axes, int input_num_dims)
{
  // Maps a possibly negative axis onto its non-negative equivalent.
  const auto resolve = [input_num_dims](int32_t axis) -> int {
    return axis >= 0 ? axis : axis + input_num_dims;
  };

  int reduction_count = num_axes;
  for (int i = 0; i < num_axes; ++i)
  {
    const int current = resolve(axes_data[i]);
    assert(current >= 0 && current < input_num_dims);
    // An entry that repeats any earlier entry must not be counted a second time.
    for (int j = 0; j < i; ++j)
    {
      if (current == resolve(axes_data[j]))
      {
        --reduction_count;
        break;
      }
    }
  }
  return reduction_count;
}
53 static Shape getOutputShape(const Shape &input_shape, const int32_t *axes_data, int num_axes,
56 int input_num_dims = input_shape.num_dims();
57 if (input_num_dims == 0)
64 Shape output_shape(input_num_dims);
65 for (int idx = 0; idx < input_num_dims; ++idx)
68 for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
70 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
78 output_shape.dim(idx) = 1;
82 output_shape.dim(idx) = input_shape.dim(idx);
89 int num_reduce_axes = getAxisReductionCount(axes_data, num_axes, input_num_dims);
90 Shape output_shape(input_num_dims - num_reduce_axes);
91 int num_skip_axes = 0;
92 for (int idx = 0; idx < input_num_dims; ++idx)
95 for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
97 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
106 output_shape.dim(idx - num_skip_axes) = input_shape.dim(idx);
113 Sum::Sum(const Tensor *input, const Tensor *axes, Tensor *output, Tensor *temp_index,
114 Tensor *resolved_axes, const ReducerParams ¶ms)
115 : KernelWithParams<ReducerParams>({input, axes}, {output, temp_index, resolved_axes}, params)
119 void Sum::configure()
121 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
122 LUCI_INTERPRETER_CHECK(axes()->element_type() == DataType::S32);
124 const Shape &input_shape = input()->shape();
125 int input_num_dims = input_shape.num_dims();
127 const auto *axes_data = getTensorData<int32_t>(axes());
128 int num_axes = axes()->shape().num_elements();
129 LUCI_INTERPRETER_CHECK(num_axes <= 4);
131 // We compute shapes of outputs in configure, assuming that outputs have
133 // TODO Support dynamic shape
134 Shape output_shape = getOutputShape(input_shape, axes_data, num_axes, _params.keep_dims);
135 output()->resize(output_shape);
137 auto temp_index = getOutputTensors()[1];
138 auto resolved_axes = getOutputTensors()[2];
140 temp_index->resize(Shape(input_num_dims));
141 resolved_axes->resize(Shape(num_axes));
144 void Sum::execute() const
146 switch (input()->element_type())
148 case DataType::FLOAT32:
152 throw std::runtime_error("Unsupported type.");
156 void Sum::evalFloat() const
158 const auto *axes_data = getTensorData<int32_t>(axes());
159 int num_axes = axes()->shape().num_elements();
161 auto temp_index = getOutputTensors()[1];
162 auto resolved_axes = getOutputTensors()[2];
164 int num_resolved_axis = 0;
165 LUCI_INTERPRETER_CHECK(
166 tflite::reference_ops::ResolveAxis(input()->shape().num_dims(), axes_data, num_axes,
167 getTensorData<int>(resolved_axes), &num_resolved_axis));
169 float init_value = 0.0;
170 tflite::reference_ops::ReduceGeneric<float>(
171 getTensorData<float>(input()), getTensorShape(input()).DimsData(), input()->shape().num_dims(),
172 getTensorData<float>(output()), getTensorShape(output()).DimsData(),
173 output()->shape().num_dims(), axes_data, num_axes, _params.keep_dims,
174 getTensorData<int>(temp_index), getTensorData<int>(resolved_axes), init_value,
175 [](const float current, const float in) -> float { return current + in; });
178 } // namespace kernels
179 } // namespace luci_interpreter