9737a014997bb697bbff98c3e5e0659e6ec61902
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Split.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Split.h"
19
20 #include "Utils.h"
21
22 #include "PALSplit.h"
23
24 namespace luci_interpreter
25 {
26 namespace kernels
27 {
28
29 Split::Split(const Tensor *axis, const Tensor *input, std::vector<Tensor *> outputs)
30   : Kernel({axis, input}, std::move(outputs))
31 {
32 }
33
34 void Split::configure()
35 {
36   assert(axis()->shape().num_elements() == 1);
37   _axis_value = getTensorData<int32_t>(axis())[0];
38   if (_axis_value < 0)
39     _axis_value += input()->shape().num_dims();
40   assert(_axis_value >= 0 && _axis_value < input()->shape().num_dims());
41
42   const int32_t input_size = input()->shape().dim(_axis_value);
43   assert(input_size % _outputs.size() == 0);
44   const int32_t slice_size = input_size / _outputs.size();
45   // TODO: enable it only if kernel with dynamic shapes
46   Shape output_shape = input()->shape();
47   output_shape.dim(_axis_value) = slice_size;
48   for (Tensor *output : _outputs)
49   {
50     output->resize(output_shape);
51   }
52 }
53
54 void Split::execute() const
55 {
56   tflite::SplitParams params{};
57   params.num_split = _outputs.size();
58   params.axis = _axis_value;
59
60 #define TF_LITE_SPLIT(scalar)                                                                    \
61   {                                                                                              \
62     VectorOfTensors<scalar, false> all_outputs(_outputs);                                        \
63     luci_interpreter_pal::Split(params, getTensorShape(input()), getTensorData<scalar>(input()), \
64                                 all_outputs.shapes(), all_outputs.data());                       \
65   }
66
67   switch (input()->element_type())
68   {
69     case DataType::FLOAT32:
70       TF_LITE_SPLIT(float);
71       break;
72     case DataType::U8:
73       TF_LITE_SPLIT(uint8_t);
74       break;
75     default:
76       assert(false && "Unsupported type.");
77   }
78 #undef TF_LITE_SPLIT
79 }
80
81 } // namespace kernels
82 } // namespace luci_interpreter