Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / common / PALResizeBilinear.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef LUCI_INTERPRETER_PAL_RESIZEBILINEAR_COMMON_H
18 #define LUCI_INTERPRETER_PAL_RESIZEBILINEAR_COMMON_H
19
20 #include "PALUtils.h"
21
22 namespace luci_interpreter_pal
23 {
24
25 // Offset function for positining corresponding index in input data
26 // int i0 - batches, int i1 - height, int i2 - width, int i3 - depth
27 inline int Offset(const luci_interpreter::RuntimeShape &shape, int i0, int i1, int i2, int i3)
28 {
29   assert(shape.dimensionsCount() == 4);
30
31   const int32_t *dims_data = reinterpret_cast<const int32_t *>(shape.dimsData());
32   LUCI_INTERPRETER_CHECK(i0 >= 0 && i0 < dims_data[0]);
33   LUCI_INTERPRETER_CHECK(i1 >= 0 && i1 < dims_data[1]);
34   LUCI_INTERPRETER_CHECK(i2 >= 0 && i2 < dims_data[2]);
35   LUCI_INTERPRETER_CHECK(i3 >= 0 && i3 < dims_data[3]);
36   return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
37 }
38
39 inline void ComputeInterpolationValues(const float value, const float scale,
40                                        const bool half_pixel_centers, int32_t input_size,
41                                        float *scaled_value, int32_t *lower_bound,
42                                        int32_t *upper_bound)
43 {
44   if (half_pixel_centers)
45   {
46     *scaled_value = (value + 0.5f) * scale - 0.5f;
47   }
48   else
49   {
50     *scaled_value = value * scale;
51   }
52   float scaled_value_floor = std::floor(*scaled_value);
53   *lower_bound = std::max(static_cast<int32_t>(scaled_value_floor), static_cast<int32_t>(0));
54   *upper_bound = std::min(static_cast<int32_t>(std::ceil(*scaled_value)), input_size - 1);
55 }
56
57 template <typename T>
58 static inline void
59 ResizeBilinear(const circle::ResizeBilinearOptions *op_params,
60                const luci_interpreter::RuntimeShape &unextended_input_shape, const T *input_data,
61                const luci_interpreter::RuntimeShape &unextended_output_size_shape,
62                const int32_t *output_size_data,
63                const luci_interpreter::RuntimeShape &unextended_output_shape, T *output_data)
64 {
65   // If half_pixel_centers is True, align_corners must be False.
66   LUCI_INTERPRETER_CHECK(!op_params->half_pixel_centers() || !op_params->align_corners());
67
68   assert(unextended_input_shape.dimensionsCount() >= 4);
69   assert(unextended_output_size_shape.dimensionsCount() >= 1);
70   assert(unextended_output_shape.dimensionsCount() >= 4);
71   const luci_interpreter::RuntimeShape input_shape =
72     luci_interpreter::RuntimeShape::extendedShape(4, unextended_input_shape);
73   const luci_interpreter::RuntimeShape output_size_shape =
74     luci_interpreter::RuntimeShape::extendedShape(4, unextended_output_size_shape);
75   const luci_interpreter::RuntimeShape output_shape =
76     luci_interpreter::RuntimeShape::extendedShape(4, unextended_output_shape);
77
78   int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
79   int32_t input_height = input_shape.dims(1);
80   int32_t input_width = input_shape.dims(2);
81   int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
82
83   assert(output_size_shape.dims(0) == 1);
84   assert(output_size_shape.dims(1) == 1);
85   assert(output_size_shape.dims(2) == 1);
86   assert(output_size_shape.dims(3) == 2);
87
88   int32_t output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
89   int32_t output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
90
91   float height_scale = static_cast<float>(input_height) / output_height;
92   float width_scale = static_cast<float>(input_width) / output_width;
93   if (op_params->align_corners() && output_height > 1)
94   {
95     height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
96   }
97   if (op_params->align_corners() && output_width > 1)
98   {
99     width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
100   }
101   const float rounding_offset = std::numeric_limits<T>::is_integer ? .5f : .0f;
102
103   for (int b = 0; b < batches; ++b)
104   {
105     for (int y = 0; y < output_height; ++y)
106     {
107       float input_y;
108       int32_t y0, y1;
109       ComputeInterpolationValues(y, height_scale, op_params->half_pixel_centers(), input_height,
110                                  &input_y, &y0, &y1);
111       for (int x = 0; x < output_width; ++x)
112       {
113         float input_x;
114         int32_t x0, x1;
115         ComputeInterpolationValues(x, width_scale, op_params->half_pixel_centers(), input_width,
116                                    &input_x, &x0, &x1);
117         for (int c = 0; c < depth; ++c)
118         {
119           T interpolation = static_cast<T>(
120             input_data[Offset(input_shape, b, y0, x0, c)] * (1 - (input_y - y0)) *
121               (1 - (input_x - x0)) +
122             input_data[Offset(input_shape, b, y1, x0, c)] * (input_y - y0) * (1 - (input_x - x0)) +
123             input_data[Offset(input_shape, b, y0, x1, c)] * (1 - (input_y - y0)) * (input_x - x0) +
124             input_data[Offset(input_shape, b, y1, x1, c)] * (input_y - y0) * (input_x - x0) +
125             rounding_offset);
126           output_data[Offset(output_shape, b, y, x, c)] = interpolation;
127         }
128       }
129     }
130   }
131 }
132
133 } // namespace luci_interpreter_pal
134
135 #endif // LUCI_INTERPRETER_PAL_RESIZEBILINEAR_COMMON_H