Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / common / PALTranspose.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_TRANSPOSE_H
19 #define LUCI_INTERPRETER_PAL_TRANSPOSE_H
20
21 #include "PALUtils.h"
22 #include "ProcessBroadcastShapes.h"
23
24 namespace luci_interpreter_pal
25 {
26 template <typename T, int N>
27 void TransposeImpl(const TransposeParams &params,
28                    const luci_interpreter::RuntimeShape &unextended_input_shape,
29                    const T *input_data,
30                    const luci_interpreter::RuntimeShape &unextended_output_shape, T *output_data)
31 {
32   const int unextended_input_size = unextended_input_shape.dimensionsCount();
33   const int unextended_output_size = unextended_output_shape.dimensionsCount();
34
35   const int input_ext_size = N - unextended_input_size;
36   const int output_ext_size = N - unextended_output_size;
37   NdArrayDesc<N> input_desc;
38   NdArrayDesc<N> output_desc;
39   copyDimsToDesc(luci_interpreter::RuntimeShape::extendedShape(N, unextended_input_shape),
40                  &input_desc);
41   copyDimsToDesc(luci_interpreter::RuntimeShape::extendedShape(N, unextended_output_shape),
42                  &output_desc);
43
44   // The perm data is extended to match the output, each index incremented by
45   // the amount of front padding of the input shape.
46   int extended_perm[N];
47   for (int i = 0; i < N; ++i)
48   {
49     extended_perm[i] = i < output_ext_size ? i : params.perm[i - output_ext_size] + input_ext_size;
50   }
51
52   // Permutes the input shape so we don't need to permute the indexes inside
53   // the loop. Check to make sure output_dims is matching input_dims.
54   NdArrayDesc<N> perm_input_desc;
55   for (int k = 0; k < N; ++k)
56   {
57     perm_input_desc.extents[k] = input_desc.extents[extended_perm[k]];
58     perm_input_desc.strides[k] = input_desc.strides[extended_perm[k]];
59   }
60
61   // Naive transpose loop (iterate on output index and compute input index).
62   auto tranpose_func = [&](int indexes[N]) {
63     output_data[subscriptToIndex(output_desc, indexes)] =
64       input_data[subscriptToIndex(perm_input_desc, indexes)];
65   };
66   NDOpsHelper<N>(output_desc, tranpose_func);
67 }
68
69 template <typename T, int N = 5>
70 void Transpose(const TransposeParams &params,
71                const luci_interpreter::RuntimeShape &unextended_input_shape, const T *input_data,
72                const luci_interpreter::RuntimeShape &unextended_output_shape, T *output_data)
73 {
74   // Transpose kernel only does rearranging values not numeric evaluations on
75   // each cell. It's safe to implement per size of scalar type and this trick
76   // keeps the total code size in a reasonable range.
77   switch (sizeof(T))
78   {
79     case 1:
80       TransposeImpl<int8_t, N>(params, unextended_input_shape,
81                                reinterpret_cast<const int8_t *>(input_data),
82                                unextended_output_shape, reinterpret_cast<int8_t *>(output_data));
83       break;
84     case 2:
85       TransposeImpl<int16_t, N>(params, unextended_input_shape,
86                                 reinterpret_cast<const int16_t *>(input_data),
87                                 unextended_output_shape, reinterpret_cast<int16_t *>(output_data));
88       break;
89
90     case 4:
91       TransposeImpl<int32_t, N>(params, unextended_input_shape,
92                                 reinterpret_cast<const int32_t *>(input_data),
93                                 unextended_output_shape, reinterpret_cast<int32_t *>(output_data));
94       break;
95     case 8:
96       TransposeImpl<int64_t, N>(params, unextended_input_shape,
97                                 reinterpret_cast<const int64_t *>(input_data),
98                                 unextended_output_shape, reinterpret_cast<int64_t *>(output_data));
99       break;
100   }
101 }
102 } // namespace luci_interpreter_pal
103
104 #endif // LUCI_INTERPRETER_PAL_TRANSPOSE_H