2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef LUCI_INTERPRETER_PAL_RESIZEBILINEAR_COMMON_H
18 #define LUCI_INTERPRETER_PAL_RESIZEBILINEAR_COMMON_H
22 namespace luci_interpreter_pal
25 // Offset function for positining corresponding index in input data
26 // int i0 - batches, int i1 - height, int i2 - width, int i3 - depth
27 inline int Offset(const luci_interpreter::RuntimeShape &shape, int i0, int i1, int i2, int i3)
29 assert(shape.dimensionsCount() == 4);
31 const int32_t *dims_data = reinterpret_cast<const int32_t *>(shape.dimsData());
32 LUCI_INTERPRETER_CHECK(i0 >= 0 && i0 < dims_data[0]);
33 LUCI_INTERPRETER_CHECK(i1 >= 0 && i1 < dims_data[1]);
34 LUCI_INTERPRETER_CHECK(i2 >= 0 && i2 < dims_data[2]);
35 LUCI_INTERPRETER_CHECK(i3 >= 0 && i3 < dims_data[3]);
36 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
39 inline void ComputeInterpolationValues(const float value, const float scale,
40 const bool half_pixel_centers, int32_t input_size,
41 float *scaled_value, int32_t *lower_bound,
44 if (half_pixel_centers)
46 *scaled_value = (value + 0.5f) * scale - 0.5f;
50 *scaled_value = value * scale;
52 float scaled_value_floor = std::floor(*scaled_value);
53 *lower_bound = std::max(static_cast<int32_t>(scaled_value_floor), static_cast<int32_t>(0));
54 *upper_bound = std::min(static_cast<int32_t>(std::ceil(*scaled_value)), input_size - 1);
59 ResizeBilinear(const circle::ResizeBilinearOptions *op_params,
60 const luci_interpreter::RuntimeShape &unextended_input_shape, const T *input_data,
61 const luci_interpreter::RuntimeShape &unextended_output_size_shape,
62 const int32_t *output_size_data,
63 const luci_interpreter::RuntimeShape &unextended_output_shape, T *output_data)
65 // If half_pixel_centers is True, align_corners must be False.
66 LUCI_INTERPRETER_CHECK(!op_params->half_pixel_centers() || !op_params->align_corners());
68 assert(unextended_input_shape.dimensionsCount() >= 4);
69 assert(unextended_output_size_shape.dimensionsCount() >= 1);
70 assert(unextended_output_shape.dimensionsCount() >= 4);
71 const luci_interpreter::RuntimeShape input_shape =
72 luci_interpreter::RuntimeShape::extendedShape(4, unextended_input_shape);
73 const luci_interpreter::RuntimeShape output_size_shape =
74 luci_interpreter::RuntimeShape::extendedShape(4, unextended_output_size_shape);
75 const luci_interpreter::RuntimeShape output_shape =
76 luci_interpreter::RuntimeShape::extendedShape(4, unextended_output_shape);
78 int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
79 int32_t input_height = input_shape.dims(1);
80 int32_t input_width = input_shape.dims(2);
81 int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
83 assert(output_size_shape.dims(0) == 1);
84 assert(output_size_shape.dims(1) == 1);
85 assert(output_size_shape.dims(2) == 1);
86 assert(output_size_shape.dims(3) == 2);
88 int32_t output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
89 int32_t output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
91 float height_scale = static_cast<float>(input_height) / output_height;
92 float width_scale = static_cast<float>(input_width) / output_width;
93 if (op_params->align_corners() && output_height > 1)
95 height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
97 if (op_params->align_corners() && output_width > 1)
99 width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
101 const float rounding_offset = std::numeric_limits<T>::is_integer ? .5f : .0f;
103 for (int b = 0; b < batches; ++b)
105 for (int y = 0; y < output_height; ++y)
109 ComputeInterpolationValues(y, height_scale, op_params->half_pixel_centers(), input_height,
111 for (int x = 0; x < output_width; ++x)
115 ComputeInterpolationValues(x, width_scale, op_params->half_pixel_centers(), input_width,
117 for (int c = 0; c < depth; ++c)
119 T interpolation = static_cast<T>(
120 input_data[Offset(input_shape, b, y0, x0, c)] * (1 - (input_y - y0)) *
121 (1 - (input_x - x0)) +
122 input_data[Offset(input_shape, b, y1, x0, c)] * (input_y - y0) * (1 - (input_x - x0)) +
123 input_data[Offset(input_shape, b, y0, x1, c)] * (1 - (input_y - y0)) * (input_x - x0) +
124 input_data[Offset(input_shape, b, y1, x1, c)] * (input_y - y0) * (input_x - x0) +
126 output_data[Offset(output_shape, b, y, x, c)] = interpolation;
133 } // namespace luci_interpreter_pal
135 #endif // LUCI_INTERPRETER_PAL_RESIZEBILINEAR_COMMON_H