/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/AveragePool2D.h"

#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/pooling.h>

#include <stdexcept>
25 namespace luci_interpreter
31 AveragePool2D::AveragePool2D(const Tensor *input, Tensor *output, const Pool2DParams ¶ms)
32 : KernelWithParams<Pool2DParams>({input}, {output}, params)
36 void AveragePool2D::configure()
38 const Shape &input_shape = input()->shape();
40 const int32_t batches = input_shape.dim(0);
41 const int32_t input_height = input_shape.dim(1);
42 const int32_t input_width = input_shape.dim(2);
43 const int32_t depth = input_shape.dim(3);
45 const int32_t output_height = computeOutputSize(_params.padding, input_height,
46 _params.filter_height, _params.stride_height);
47 const int32_t output_width =
48 computeOutputSize(_params.padding, input_width, _params.filter_width, _params.stride_width);
51 computePadding(_params.stride_height, 1, input_height, _params.filter_height, output_height);
53 computePadding(_params.stride_width, 1, input_width, _params.filter_width, output_width);
55 output()->resize({batches, output_height, output_width, depth});
58 void AveragePool2D::execute() const
60 switch (input()->element_type())
62 case DataType::FLOAT32:
69 throw std::runtime_error("Unsupported type.");
73 void AveragePool2D::evalFloat() const
75 float activation_min{};
76 float activation_max{};
77 calculateActivationRange(_params.activation, &activation_min, &activation_max);
79 tflite::PoolParams params{};
80 params.padding_values.height = _padding_height;
81 params.padding_values.width = _padding_width;
82 params.stride_height = _params.stride_height;
83 params.stride_width = _params.stride_width;
84 params.filter_height = _params.filter_height;
85 params.filter_width = _params.filter_width;
86 params.float_activation_min = activation_min;
87 params.float_activation_max = activation_max;
89 tflite::reference_ops::AveragePool(params, getTensorShape(input()), getTensorData<float>(input()),
90 getTensorShape(output()), getTensorData<float>(output()));
93 void AveragePool2D::evalQuantized() const
95 int32_t activation_min{};
96 int32_t activation_max{};
97 calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
99 tflite::PoolParams params{};
100 params.padding_values.height = _padding_height;
101 params.padding_values.width = _padding_width;
102 params.stride_height = _params.stride_height;
103 params.stride_width = _params.stride_width;
104 params.filter_height = _params.filter_height;
105 params.filter_width = _params.filter_width;
106 params.quantized_activation_min = activation_min;
107 params.quantized_activation_max = activation_max;
109 tflite::reference_ops::AveragePool(params, getTensorShape(input()),
110 getTensorData<uint8_t>(input()), getTensorShape(output()),
111 getTensorData<uint8_t>(output()));
114 } // namespace kernels
115 } // namespace luci_interpreter