f01cf10b74962b0e8f6572d10e033b45dc018781
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / SplitV.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "SplitV.h"
19
20 #include "Utils.h"
21
22 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
23
24 namespace luci_interpreter
25 {
26 namespace kernels
27 {
28
29 SplitV::SplitV(const Tensor *input, const Tensor *size_splits, const Tensor *axis,
30                std::vector<Tensor *> outputs)
31   : Kernel({input, size_splits, axis}, std::move(outputs))
32 {
33 }
34
35 void SplitV::configure()
36 {
37   assert(axis()->shape().num_elements() == 1);
38   _axis_value = getTensorData<int32_t>(axis())[0];
39   if (_axis_value < 0)
40     _axis_value += input()->shape().num_dims();
41   assert(_axis_value >= 0 && _axis_value < input()->shape().num_dims());
42
43   auto num_split = static_cast<int32_t>(_outputs.size());
44   auto sizes_data = getTensorData<int32_t>(size_splits());
45
46   assert(size_splits()->shape().num_dims() == 1);
47
48   int32_t sum = 0;
49   const auto num_dims_size_spits = size_splits()->shape().dim(0);
50   int32_t count_neg_dim = 0;
51
52   for (int32_t i = 0; i < num_dims_size_spits - 1; ++i)
53   {
54     if (sizes_data[i] != -1)
55     {
56       sum += sizes_data[i];
57     }
58     else
59     {
60       count_neg_dim++;
61     }
62   }
63   assert(count_neg_dim < 2);
64   assert(size_splits()->shape().num_elements() == num_split);
65
66   // TODO: enable it only if kernel with dynamic shapes
67   auto output_shape = input()->shape();
68   for (int32_t i = 0; i < num_split; ++i)
69   {
70     if (sizes_data[i] == -1)
71     {
72       output_shape.dim(_axis_value) = input()->shape().dim(_axis_value) - sum;
73     }
74     else
75     {
76       output_shape.dim(_axis_value) = sizes_data[i];
77     }
78     _outputs[i]->resize(output_shape);
79   }
80 }
81
82 void SplitV::execute() const
83 {
84   tflite::SplitParams params{};
85   params.num_split = _outputs.size();
86   params.axis = _axis_value;
87
88 #define TF_LITE_SPLIT(scalar)                                                                     \
89   {                                                                                               \
90     VectorOfTensors<scalar, false> all_outputs(_outputs);                                         \
91     tflite::optimized_ops::Split(params, getTensorShape(input()), getTensorData<scalar>(input()), \
92                                  all_outputs.shapes(), all_outputs.data());                       \
93   }
94
95   switch (input()->element_type())
96   {
97     case DataType::FLOAT32:
98       TF_LITE_SPLIT(float);
99       break;
100     case DataType::U8:
101       TF_LITE_SPLIT(uint8_t);
102       break;
103     case DataType::S16:
104       TF_LITE_SPLIT(int16_t);
105       break;
106     default:
107       assert(false && "Unsupported type.");
108   }
109 #undef TF_LITE_SPLIT
110 }
111
112 } // namespace kernels
113 } // namespace luci_interpreter