Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Pack.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Builders.h"
19 #include "Utils.h"
20
21 #include <cassert>
22
23 namespace luci_interpreter
24 {
25 namespace
26 {
27
28 template <typename T>
29 void packImpl(const circle::Tensor *input0, const circle::Tensor *output,
30               const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
31               uint8_t *output_data_raw)
32 {
33   const auto *options = cur_op->builtin_options_as_PackOptions();
34
35   const int values_count = options->values_count();
36   int axis = options->axis();
37   const int dimensions = Tensor::num_dims(output);
38
39   const auto input_dims = wrap(input0->shape());
40   const auto output_dims = wrap(output->shape());
41
42   if (axis < 0)
43   {
44     axis += dimensions;
45   }
46
47   int outer_size = 1;
48   for (int i = 0; i < axis; ++i)
49     outer_size *= output_dims[i];
50
51   int copy_size = 1;
52   for (int i = axis + 1; i < dimensions; ++i)
53     copy_size *= output_dims[i];
54
55   int input_size = 1;
56   for (int i = 0; i < input_dims.size(); ++i)
57     input_size *= input_dims[i];
58
59   assert(input_size == copy_size * outer_size);
60
61   T *output_data = kernels::getTensorData<T>(output_data_raw);
62   assert(output_data != nullptr);
63
64   for (int i = 0; i < values_count; ++i)
65   {
66     const auto input_index = cur_op->inputs()->operator[](i);
67     assert(input_index != -1);
68     const auto input = runtime_graph->getCircleTensorByIndex(input_index);
69
70     auto input_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(input));
71     assert(input_data != nullptr);
72     for (int k = 0; k < outer_size; ++k)
73     {
74       const T *input_ptr = input_data + copy_size * k;
75       int loc = k * values_count * copy_size + i * copy_size;
76       T *output_ptr = output_data + loc;
77       for (int j = 0; j < copy_size; ++j)
78         output_ptr[j] = input_ptr[j];
79     }
80   }
81 }
82
83 } // namespace
84
85 void configure_kernel_CirclePack(const circle::Operator *, BaseRuntimeGraph *)
86 {
87   // Do nothing
88 }
89
90 void execute_kernel_CirclePack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
91 {
92   const auto input_index = cur_op->inputs()->operator[](0);
93   const auto output_index = cur_op->outputs()->operator[](0);
94   assert(output_index != -1);
95   assert(input_index != -1);
96   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
97   const auto output = runtime_graph->getCircleTensorByIndex(output_index);
98
99   auto output_data = runtime_graph->getDataByTensor(output);
100   assert(output_data != nullptr);
101
102   switch (Tensor::element_type(output))
103   {
104 #ifndef DIS_FLOAT
105     case DataType::FLOAT32:
106       packImpl<float>(input, output, cur_op, runtime_graph, output_data);
107       break;
108 #endif // DIS_FLOAT
109 #ifndef DIS_QUANT
110     case DataType::S8:
111       packImpl<int8_t>(input, output, cur_op, runtime_graph, output_data);
112       break;
113     case DataType::U8:
114       packImpl<uint8_t>(input, output, cur_op, runtime_graph, output_data);
115       break;
116 #endif // DIS_QUANT
117     case DataType::S32:
118       packImpl<int32_t>(input, output, cur_op, runtime_graph, output_data);
119       break;
120     case DataType::S64:
121       packImpl<int64_t>(input, output, cur_op, runtime_graph, output_data);
122       break;
123     default:
124       assert(false && "Unsupported types");
125   }
126 }
127
128 } // namespace luci_interpreter