7ae24bbba085a31050af529bd8c927e6116ceb5b
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Slice.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Slice.h"
19 #include "Utils.h"
20 #include "PALSlice.h"
21
22 #include <cassert>
23 #include <cstring>
24
25 namespace luci_interpreter
26 {
27
28 namespace kernels
29 {
30 const int max_dim = 4;
31
32 Slice::Slice(const Tensor *input, const Tensor *begin, const Tensor *size, Tensor *output)
33   : Kernel({input, begin, size}, {output})
34 {
35 }
36
37 template <typename T>
38 Shape calculateOutputShape(const Tensor *input, const Tensor *begin, const Tensor *size)
39 {
40   Shape output_shape = Shape(input->shape().num_dims());
41   for (int idx = 0; idx < input->shape().num_dims(); idx++)
42   {
43     T size_value = getTensorData<T>(size)[idx];
44     if (size_value < 0)
45     {
46       if (size_value != -1)
47       {
48         assert(false && "Invalid size.");
49       }
50       size_value = input->shape().dim(idx) - getTensorData<T>(begin)[idx];
51     }
52     else
53     {
54       if (input->shape().dim(idx) < getTensorData<T>(begin)[idx] + size_value)
55       {
56         assert(false && "Invalid begin and size.");
57       }
58     }
59     output_shape.dim(idx) = static_cast<int>(size_value);
60   }
61   return output_shape;
62 }
63
64 template <typename T>
65 void getBeginAndSizeVectors(int dimensions, const Tensor *begin, const Tensor *size,
66                             std::vector<int> *begins, std::vector<int> *sizes)
67 {
68   for (int idx = dimensions - 1; idx >= 0; --idx)
69   {
70     begins->push_back(getTensorData<T>(begin)[idx]);
71     sizes->push_back(getTensorData<T>(size)[idx]);
72   }
73 }
74
75 void Slice::configure()
76 {
77   assert(input()->element_type() == output()->element_type());
78   assert(begin()->element_type() == DataType::S32 || begin()->element_type() == DataType::S64);
79   assert(size()->element_type() == DataType::S32 || size()->element_type() == DataType::S64);
80   assert(begin()->shape().num_dims() == 1);
81   assert(size()->shape().num_dims() == 1);
82   assert(input()->shape().num_dims() <= max_dim);
83   // TODO: enable it only if kernel with dynamic shapes
84   if (begin()->element_type() == DataType::S32)
85   {
86     output()->resize(calculateOutputShape<int32_t>(input(), begin(), size()));
87   }
88   else if (begin()->element_type() == DataType::S64)
89   {
90     output()->resize(calculateOutputShape<int64_t>(input(), begin(), size()));
91   }
92   else
93   {
94     assert(false && "Unsupported type.");
95   }
96 }
97
98 void Slice::execute() const
99 {
100   std::vector<int> begins;
101   begins.reserve(max_dim);
102   std::vector<int> sizes;
103   sizes.reserve(max_dim);
104   if (begin()->element_type() == DataType::S32)
105   {
106     getBeginAndSizeVectors<int32_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
107   }
108   else if (begin()->element_type() == DataType::S64)
109   {
110     getBeginAndSizeVectors<int64_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
111   }
112   else
113   {
114     assert(false && "Unsupported begin type.");
115   }
116   for (int i = input()->shape().num_dims(); i < max_dim; ++i)
117   {
118     begins.push_back(0);
119     sizes.push_back(1);
120   }
121
122   assert(begins.size() == 4);
123   assert(sizes.size() == 4);
124   tflite::SliceParams op_params{};
125   op_params.begin_count = 4;
126   op_params.size_count = 4;
127   for (int i = 0; i < 4; i++)
128   {
129     op_params.begin[i] = begins[3 - i];
130     op_params.size[i] = sizes[3 - i];
131   }
132   switch (input()->element_type())
133   {
134     case DataType::FLOAT32:
135       luci_interpreter_pal::Slice(op_params, getTensorShape(input()), getTensorData<float>(input()),
136                                   getTensorShape(output()), getTensorData<float>(output()));
137       break;
138     case DataType::U8:
139       luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
140                                   getTensorData<uint8_t>(input()), getTensorShape(output()),
141                                   getTensorData<uint8_t>(output()));
142       break;
143     case DataType::S8:
144       luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
145                                   getTensorData<int8_t>(input()), getTensorShape(output()),
146                                   getTensorData<int8_t>(output()));
147       break;
148     default:
149       assert(false && "Unsupported input type.");
150   }
151 }
152
153 } // namespace kernels
154 } // namespace luci_interpreter