2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
20 #include "kernels/Utils.h"
22 #include <tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h>
23 #include <tensorflow/lite/kernels/internal/reference/pooling.h>
25 namespace luci_interpreter
33 void evalFloat(const circle::Tensor *input, const circle::Tensor *output,
34 const circle::Pool2DOptions *options, BaseRuntimeGraph *runtime_graph)
36 const int32_t input_height = Tensor::dim(input, 1);
37 const int32_t input_width = Tensor::dim(input, 2);
39 const int32_t output_height = kernels::computeOutputSize(
40 luci_padding(options->padding()), input_height, options->filter_height(), options->stride_h());
41 const int32_t output_width = kernels::computeOutputSize(
42 luci_padding(options->padding()), input_width, options->filter_width(), options->stride_w());
44 const auto padding_height = kernels::computePadding(options->stride_h(), 1, input_height,
45 options->filter_height(), output_height);
46 const auto padding_width = kernels::computePadding(options->stride_w(), 1, input_width,
47 options->filter_width(), output_width);
49 const auto *input_data = runtime_graph->getDataByTensor(input);
50 auto *output_data = runtime_graph->getDataByTensor(output);
52 float activation_min{};
53 float activation_max{};
54 kernels::calculateActivationRange(luci_actfunc(options->fused_activation_function()),
55 &activation_min, &activation_max);
56 tflite::PoolParams params{};
57 params.padding_values.height = padding_height;
58 params.padding_values.width = padding_width;
59 params.stride_height = options->stride_h();
60 params.stride_width = options->stride_w();
61 params.filter_height = options->filter_height();
62 params.filter_width = options->filter_width();
63 params.float_activation_min = activation_min;
64 params.float_activation_max = activation_max;
66 tflite::reference_ops::MaxPool(
67 params, kernels::getTensorShape(input), kernels::getTensorData<float>(input_data),
68 kernels::getTensorShape(output), kernels::getTensorData<float>(output_data));
74 void evalQuantized(const circle::Tensor *input, const circle::Tensor *output,
75 const circle::Pool2DOptions *options, BaseRuntimeGraph *runtime_graph)
77 int32_t activation_min{};
78 int32_t activation_max{};
79 kernels::calculateActivationRangeQuantized(luci_actfunc(options->fused_activation_function()),
80 output, &activation_min, &activation_max);
83 const int32_t input_height = Tensor::dim(input, 1);
84 const int32_t input_width = Tensor::dim(input, 2);
86 const int32_t output_height = kernels::computeOutputSize(
87 luci_padding(options->padding()), input_height, options->filter_height(), options->stride_h());
88 const int32_t output_width = kernels::computeOutputSize(
89 luci_padding(options->padding()), input_width, options->filter_width(), options->stride_w());
91 const auto padding_height = kernels::computePadding(options->stride_h(), 1, input_height,
92 options->filter_height(), output_height);
93 const auto padding_width = kernels::computePadding(options->stride_w(), 1, input_width,
94 options->filter_width(), output_width);
96 tflite::PoolParams params{};
97 params.padding_values.height = padding_height;
98 params.padding_values.width = padding_width;
99 params.stride_height = options->stride_h();
100 params.stride_width = options->stride_w();
101 params.filter_height = options->filter_height();
102 params.filter_width = options->filter_width();
103 params.quantized_activation_min = activation_min;
104 params.quantized_activation_max = activation_max;
106 const auto *input_data = runtime_graph->getDataByTensor(input);
107 auto *output_data = runtime_graph->getDataByTensor(output);
109 tflite::reference_ops::MaxPool(
110 params, kernels::getTensorShape(input), kernels::getTensorData<uint8_t>(input_data),
111 kernels::getTensorShape(output), kernels::getTensorData<uint8_t>(output_data));
114 void evalSInt16(const circle::Tensor *input, const circle::Tensor *output,
115 const circle::Pool2DOptions *options, BaseRuntimeGraph *runtime_graph)
117 int32_t activation_min{};
118 int32_t activation_max{};
119 kernels::calculateActivationRangeQuantized(luci_actfunc(options->fused_activation_function()),
120 output, &activation_min, &activation_max);
123 const int32_t input_height = Tensor::dim(input, 1);
124 const int32_t input_width = Tensor::dim(input, 2);
126 const int32_t output_height = kernels::computeOutputSize(
127 luci_padding(options->padding()), input_height, options->filter_height(), options->stride_h());
128 const int32_t output_width = kernels::computeOutputSize(
129 luci_padding(options->padding()), input_width, options->filter_width(), options->stride_w());
131 const auto padding_height = kernels::computePadding(options->stride_h(), 1, input_height,
132 options->filter_height(), output_height);
133 const auto padding_width = kernels::computePadding(options->stride_w(), 1, input_width,
134 options->filter_width(), output_width);
136 tflite::PoolParams params{};
137 params.padding_values.height = padding_height;
138 params.padding_values.width = padding_width;
139 params.stride_height = options->stride_h();
140 params.stride_width = options->stride_w();
141 params.filter_height = options->filter_height();
142 params.filter_width = options->filter_width();
143 params.quantized_activation_min = activation_min;
144 params.quantized_activation_max = activation_max;
146 const auto *input_data = runtime_graph->getDataByTensor(input);
147 auto *output_data = runtime_graph->getDataByTensor(output);
149 tflite::reference_integer_ops::MaxPool(
150 params, kernels::getTensorShape(input), kernels::getTensorData<int16_t>(input_data),
151 kernels::getTensorShape(output), kernels::getTensorData<int16_t>(output_data));
158 void configure_kernel_CircleMaxPool2D(const circle::Operator *cur_op,
159 BaseRuntimeGraph *runtime_graph)
161 const auto input_index = cur_op->inputs()->operator[](0);
162 const auto output_index = cur_op->outputs()->operator[](0);
164 assert(input_index != -1);
165 assert(output_index != -1);
167 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
168 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
170 LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == Tensor::element_type(output));
171 assert(Tensor::num_dims(input) == 4);
174 if (Tensor::element_type(input) == DataType::U8)
176 LUCI_INTERPRETER_CHECK(std::abs(Tensor::scale(output) - Tensor::scale(input)) <= 1.0e-6);
177 LUCI_INTERPRETER_CHECK(Tensor::zero_point(output) == Tensor::zero_point(input));
179 else if (Tensor::element_type(input) == DataType::S16)
181 LUCI_INTERPRETER_CHECK(std::abs(Tensor::scale(output) - Tensor::scale(input)) <= 1.0e-6);
182 LUCI_INTERPRETER_CHECK(Tensor::zero_point(input) == 0 && Tensor::zero_point(output) == 0);
// Executes one MaxPool2D operator. It looks up the operator's first input and first output
// tensors, fetches its Pool2D builtin options, and forwards all three to the evaluation
// routine for the input's element type.
// NOTE(review): this view does not show the rest of the parameter list after
// `runtime_graph,`, nor the case labels, `break`s and `default:` of the switch. Presumably the
// evalQuantized call is the DataType::U8 branch and evalSInt16 the DataType::S16 branch, the
// types configure_kernel_CircleMaxPool2D handles. Confirm against the full file.
187 void execute_kernel_CircleMaxPool2D(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
190 const auto input_index = cur_op->inputs()->operator[](0);
191 const auto output_index = cur_op->outputs()->operator[](0);
// A tensor index of -1 is rejected here. The circle/TFLite schema uses -1 for an omitted tensor.
193 assert(input_index != -1);
194 assert(output_index != -1);
196 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
197 auto output = runtime_graph->getCircleTensorByIndex(output_index);
// The eval* routines read filter size, strides, padding mode and fused activation from these
// options.
199 const auto *options = cur_op->builtin_options_as_Pool2DOptions();
// Dispatch on the input element type. configure_kernel_CircleMaxPool2D checks that the output
// has the same type.
201 switch (Tensor::element_type(input))
204 case DataType::FLOAT32:
205 evalFloat(input, output, options, runtime_graph);
210 evalQuantized(input, output, options, runtime_graph);
213 evalSInt16(input, output, options, runtime_graph);
// Unsupported element type. assert() is compiled out when NDEBUG is defined. In such builds,
// nothing visible here reports the error, and this branch does not write the output tensor.
217 assert(false && "Unsupported type.");
221 } // namespace luci_interpreter