2b759a829ccc20a8c7b8788eceb0b16ac31c8795
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / PadV2.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/PadV2.h"
19
20 #include "kernels/Utils.h"
21
22 #include <tensorflow/lite/kernels/internal/reference/pad.h>
23
24 #include <limits>
25
26 namespace luci_interpreter
27 {
28 namespace kernels
29 {
30
31 PadV2::PadV2(const Tensor *input, const Tensor *paddings, const Tensor *constant_values,
32              Tensor *output)
33   : Kernel({input, paddings, constant_values}, {output})
34 {
35 }
36
37 void PadV2::configure()
38 {
39   const Shape &input_shape = input()->shape();
40   const int num_dims = input_shape.num_dims();
41
42   if (num_dims > 4)
43     assert(false && "Unsupported number of dimensions.");
44
45   assert(output()->element_type() == input()->element_type());
46   assert(paddings()->element_type() == DataType::S32);
47   assert(constant_values()->element_type() == output()->element_type());
48   // Paddings shape should be [N, 2].
49   assert(paddings()->shape().num_dims() == 2);
50   assert(paddings()->shape().dim(0) == num_dims);
51   assert(paddings()->shape().dim(1) == 2);
52   // Constant values elements number should be 1.
53   assert(constant_values()->shape().num_elements() == 1);
54
55   Shape output_shape(num_dims);
56   const auto *paddings_data = getTensorData<int32_t>(paddings());
57   for (int i = 0; i < num_dims; ++i)
58   {
59     const int32_t padding_before = paddings_data[i * 2];
60     const int32_t padding_after = paddings_data[i * 2 + 1];
61     assert(padding_before >= 0 && padding_after >= 0);
62     output_shape.dim(i) = input_shape.dim(i) + padding_before + padding_after;
63   }
64   // TODO: enable it only if kernel with dynamic shapes
65   output()->resize(output_shape);
66 }
67
68 void PadV2::execute() const
69 {
70   const int num_dims = input()->shape().num_dims();
71
72   tflite::PadParams params{};
73   params.left_padding_count = num_dims;
74   params.right_padding_count = num_dims;
75
76   const auto *paddings_data = getTensorData<int32_t>(paddings());
77   for (int i = num_dims - 1; i >= 0; --i)
78   {
79     params.left_padding[i] = paddings_data[i * 2];
80     params.right_padding[i] = paddings_data[i * 2 + 1];
81   }
82
83   switch (input()->element_type())
84   {
85     case DataType::FLOAT32:
86     {
87       const auto pad_value = getTensorData<float>(constant_values())[0];
88       tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<float>(input()),
89                                  &pad_value, getTensorShape(output()),
90                                  getTensorData<float>(output()));
91       break;
92     }
93     case DataType::U8:
94     {
95       assert(output()->zero_point() >= std::numeric_limits<uint8_t>::min());
96       assert(output()->zero_point() <= std::numeric_limits<uint8_t>::max());
97       const auto pad_value = getTensorData<uint8_t>(constant_values())[0];
98       tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
99                                  &pad_value, getTensorShape(output()),
100                                  getTensorData<uint8_t>(output()));
101       break;
102     }
103     default:
104       assert(false && "Unsupported type.");
105   }
106 }
107
108 } // namespace kernels
109 } // namespace luci_interpreter