Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Transpose.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Builders.h"
19 #include "TISOKernel.h"
20 #include "kernels/Utils.h"
21
22 #include "PALTranspose.h"
23
24 namespace luci_interpreter
25 {
26 void configure_kernel_CircleTranspose(const circle::Operator *cur_op,
27                                       BaseRuntimeGraph *runtime_graph)
28 {
29   kernels::TISOKernel kernel(cur_op, runtime_graph);
30
31   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32);
32
33   const int32_t dims = Tensor::num_dims(kernel.input1());
34   const int32_t *perm_data =
35     kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(kernel.input2()));
36
37   // Ensure validity of the permutations tensor as a 1D tensor
38   LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input2()) == 1);
39   LUCI_INTERPRETER_CHECK(Tensor::dim(kernel.input2(), 0) == dims);
40
41   for (int idx = 0; idx < dims; ++idx)
42     LUCI_INTERPRETER_CHECK(perm_data[idx] >= 0 and perm_data[idx] < dims);
43 }
44
45 void execute_kernel_CircleTranspose(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
46 {
47   kernels::TISOKernel kernel(cur_op, runtime_graph);
48
49   const circle::Tensor *input = kernel.input1();
50   const circle::Tensor *perm = kernel.input2();
51   const circle::Tensor *output = kernel.output();
52
53   kernels::TISOData tiso_data = kernel.readData();
54   const int32_t *perm_data = kernels::getTensorData<int32_t>(tiso_data.input2_data);
55
56   const int32_t size = Tensor::dim(perm, 0);
57   luci_interpreter_pal::TransposeParams params;
58   params.perm_count = size;
59   for (int i = 0; i < size; ++i)
60     params.perm[i] = perm_data[i];
61
62   switch (Tensor::element_type(input))
63   {
64 #ifndef DIS_FLOAT
65     case DataType::FLOAT32:
66       luci_interpreter_pal::Transpose(params, kernels::getTensorShape(input),
67                                       kernels::getTensorData<float>(tiso_data.input1_data),
68                                       kernels::getTensorShape(output),
69                                       kernels::getTensorData<float>(tiso_data.output_data));
70       break;
71 #endif // DIS_FLOAT
72 #ifndef DIS_QUANT
73     case DataType::U8:
74       luci_interpreter_pal::Transpose(params, kernels::getTensorShape(input),
75                                       kernels::getTensorData<uint8_t>(tiso_data.input1_data),
76                                       kernels::getTensorShape(output),
77                                       kernels::getTensorData<uint8_t>(tiso_data.output_data));
78       break;
79 #endif // DIS_QUANT
80     default:
81       assert(false && "Unsupported type");
82   }
83 }
84
85 } // namespace luci_interpreter