Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Split.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Builders.h"
18 #include "Utils.h"
19 #include "Split.h"
20
21 namespace luci_interpreter
22 {
23
24 void configure_kernel_CircleSplit(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
25 {
26   const auto input_index = cur_op->inputs()->operator[](0);
27   const auto axis_index = cur_op->inputs()->operator[](1);
28
29   LUCI_INTERPRETER_CHECK(input_index != -1);
30   LUCI_INTERPRETER_CHECK(axis_index != -1);
31
32   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
33   const auto axis = runtime_graph->getCircleTensorByIndex(axis_index);
34
35   LUCI_INTERPRETER_CHECK(input != nullptr);
36   LUCI_INTERPRETER_CHECK(axis != nullptr);
37 }
38
39 void execute_kernel_CircleSplit(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
40 {
41   const auto input_index = cur_op->inputs()->operator[](1);
42   const auto axis_index = cur_op->inputs()->operator[](0);
43
44   assert(input_index != -1);
45   assert(axis_index != -1);
46
47   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
48   const auto axis = runtime_graph->getCircleTensorByIndex(axis_index);
49
50   assert(input != nullptr);
51   assert(axis != nullptr);
52
53   const auto *axis_data = runtime_graph->getDataByTensor(axis);
54   if (axis_data == nullptr)
55     axis_data = runtime_graph->getConstDataByTensor(axis);
56
57   assert(axis_data);
58
59   int32_t axis_value = (kernels::getTensorData<int32_t>(axis_data))[0];
60   if (axis_value < 0)
61     axis_value += Tensor::num_dims(input);
62
63   assert(axis_value >= 0);
64   assert(axis_value < Tensor::num_dims(input));
65
66   switch (Tensor::element_type(input))
67   {
68 #ifndef DIS_FLOAT
69     case DataType::FLOAT32:
70     {
71       return splitImpl<float>(cur_op, input, axis_value, runtime_graph);
72     }
73 #endif // DIS_FLOAT
74 #ifndef DIS_QUANT
75     case DataType::S8:
76     {
77       return splitImpl<int8_t>(cur_op, input, axis_value, runtime_graph);
78     }
79     case DataType::S16:
80     {
81       return splitImpl<int16_t>(cur_op, input, axis_value, runtime_graph);
82     }
83 #endif // DIS_QUANT
84     case DataType::S32:
85     {
86       return splitImpl<int32_t>(cur_op, input, axis_value, runtime_graph);
87     }
88     default:
89       assert(false && "Unsupported type");
90   }
91 }
92
93 } // namespace luci_interpreter