Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Split.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Split.h"
18
19 #include "Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
22
23 namespace luci_interpreter
24 {
25 namespace kernels
26 {
27
28 Split::Split(const Tensor *axis, const Tensor *input, std::vector<Tensor *> outputs)
29     : Kernel({axis, input}, std::move(outputs))
30 {
31 }
32
33 void Split::configure()
34 {
35   assert(axis()->shape().num_elements() == 1);
36   _axis_value = getTensorData<int32_t>(axis())[0];
37   if (_axis_value < 0)
38     _axis_value += input()->shape().num_dims();
39   assert(_axis_value >= 0 && _axis_value < input()->shape().num_dims());
40
41   const int32_t input_size = input()->shape().dim(_axis_value);
42   assert(input_size % _outputs.size() == 0);
43   const int32_t slice_size = input_size / _outputs.size();
44
45   Shape output_shape = input()->shape();
46   output_shape.dim(_axis_value) = slice_size;
47   for (Tensor *output : _outputs)
48   {
49     output->resize(output_shape);
50   }
51 }
52
53 void Split::execute() const
54 {
55   tflite::SplitParams params{};
56   params.num_split = _outputs.size();
57   params.axis = _axis_value;
58
59 #define TF_LITE_SPLIT(scalar)                                                                     \
60   {                                                                                               \
61     VectorOfTensors<scalar, false> all_outputs(_outputs);                                         \
62     tflite::optimized_ops::Split(params, getTensorShape(input()), getTensorData<scalar>(input()), \
63                                  all_outputs.shapes(), all_outputs.data());                       \
64   }
65
66   switch (input()->element_type())
67   {
68     case DataType::FLOAT32:
69       TF_LITE_SPLIT(float);
70       break;
71     case DataType::U8:
72       TF_LITE_SPLIT(uint8_t);
73       break;
74     default:
75       throw std::runtime_error("Unsupported type.");
76   }
77 #undef TF_LITE_SPLIT
78 }
79
80 } // namespace kernels
81 } // namespace luci_interpreter