/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef LUCI_INTERPRETER_PAL_TRANSPOSE_H
19 #define LUCI_INTERPRETER_PAL_TRANSPOSE_H
#include "ProcessBroadcastShapes.h"

#include <cassert>
#include <cstdint>
24 namespace luci_interpreter_pal
26 template <typename T, int N>
27 void TransposeImpl(const TransposeParams ¶ms,
28 const luci_interpreter::RuntimeShape &unextended_input_shape,
30 const luci_interpreter::RuntimeShape &unextended_output_shape, T *output_data)
32 const int unextended_input_size = unextended_input_shape.dimensionsCount();
33 const int unextended_output_size = unextended_output_shape.dimensionsCount();
35 const int input_ext_size = N - unextended_input_size;
36 const int output_ext_size = N - unextended_output_size;
37 NdArrayDesc<N> input_desc;
38 NdArrayDesc<N> output_desc;
39 copyDimsToDesc(luci_interpreter::RuntimeShape::extendedShape(N, unextended_input_shape),
41 copyDimsToDesc(luci_interpreter::RuntimeShape::extendedShape(N, unextended_output_shape),
44 // The perm data is extended to match the output, each index incremented by
45 // the amount of front padding of the input shape.
47 for (int i = 0; i < N; ++i)
49 extended_perm[i] = i < output_ext_size ? i : params.perm[i - output_ext_size] + input_ext_size;
52 // Permutes the input shape so we don't need to permute the indexes inside
53 // the loop. Check to make sure output_dims is matching input_dims.
54 NdArrayDesc<N> perm_input_desc;
55 for (int k = 0; k < N; ++k)
57 perm_input_desc.extents[k] = input_desc.extents[extended_perm[k]];
58 perm_input_desc.strides[k] = input_desc.strides[extended_perm[k]];
61 // Naive transpose loop (iterate on output index and compute input index).
62 auto tranpose_func = [&](int indexes[N]) {
63 output_data[subscriptToIndex(output_desc, indexes)] =
64 input_data[subscriptToIndex(perm_input_desc, indexes)];
66 NDOpsHelper<N>(output_desc, tranpose_func);
69 template <typename T, int N = 5>
70 void Transpose(const TransposeParams ¶ms,
71 const luci_interpreter::RuntimeShape &unextended_input_shape, const T *input_data,
72 const luci_interpreter::RuntimeShape &unextended_output_shape, T *output_data)
74 // Transpose kernel only does rearranging values not numeric evaluations on
75 // each cell. It's safe to implement per size of scalar type and this trick
76 // keeps the total code size in a reasonable range.
80 TransposeImpl<int8_t, N>(params, unextended_input_shape,
81 reinterpret_cast<const int8_t *>(input_data),
82 unextended_output_shape, reinterpret_cast<int8_t *>(output_data));
85 TransposeImpl<int16_t, N>(params, unextended_input_shape,
86 reinterpret_cast<const int16_t *>(input_data),
87 unextended_output_shape, reinterpret_cast<int16_t *>(output_data));
91 TransposeImpl<int32_t, N>(params, unextended_input_shape,
92 reinterpret_cast<const int32_t *>(input_data),
93 unextended_output_shape, reinterpret_cast<int32_t *>(output_data));
96 TransposeImpl<int64_t, N>(params, unextended_input_shape,
97 reinterpret_cast<const int64_t *>(input_data),
98 unextended_output_shape, reinterpret_cast<int64_t *>(output_data));
102 } // namespace luci_interpreter_pal
104 #endif // LUCI_INTERPRETER_PAL_TRANSPOSE_H