Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Slice.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/Slice.h"
18 #include "Utils.h"
19 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
20
21 #include <cassert>
22 #include <cstring>
23
24 namespace luci_interpreter
25 {
26
27 namespace kernels
28 {
29 const int max_dim = 4;
30
31 Slice::Slice(const Tensor *input, const Tensor *begin, const Tensor *size, Tensor *output)
32     : Kernel({input, begin, size}, {output})
33 {
34 }
35
36 template <typename T>
37 Shape calculateOutputShape(const Tensor *input, const Tensor *begin, const Tensor *size)
38 {
39   Shape output_shape = Shape(input->shape().num_dims());
40   for (int idx = 0; idx < input->shape().num_dims(); idx++)
41   {
42     T size_value = getTensorData<T>(size)[idx];
43     if (size_value < 0)
44     {
45       if (size_value != -1)
46       {
47         throw std::runtime_error("Invalid size.");
48       }
49       size_value = input->shape().dim(idx) - getTensorData<T>(begin)[idx];
50     }
51     else
52     {
53       if (input->shape().dim(idx) < getTensorData<T>(begin)[idx] + size_value)
54       {
55         throw std::runtime_error("Invalid begin and size.");
56       }
57     }
58     output_shape.dim(idx) = static_cast<int>(size_value);
59   }
60   return output_shape;
61 }
62
63 template <typename T>
64 void getBeginAndSizeVectors(int dimensions, const Tensor *begin, const Tensor *size,
65                             std::vector<int> *begins, std::vector<int> *sizes)
66 {
67   for (int idx = dimensions - 1; idx >= 0; --idx)
68   {
69     begins->push_back(getTensorData<T>(begin)[idx]);
70     sizes->push_back(getTensorData<T>(size)[idx]);
71   }
72 }
73
74 void Slice::configure()
75 {
76   assert(input()->element_type() == output()->element_type());
77   assert(begin()->element_type() == DataType::S32 || begin()->element_type() == DataType::S64);
78   assert(size()->element_type() == DataType::S32 || size()->element_type() == DataType::S64);
79   assert(begin()->shape().num_dims() == 1);
80   assert(size()->shape().num_dims() == 1);
81   assert(input()->shape().num_dims() <= max_dim);
82
83   if (begin()->element_type() == DataType::S32)
84   {
85     output()->resize(calculateOutputShape<int32_t>(input(), begin(), size()));
86   }
87   else if (begin()->element_type() == DataType::S64)
88   {
89     output()->resize(calculateOutputShape<int64_t>(input(), begin(), size()));
90   }
91   else
92   {
93     throw std::runtime_error("Unsupported type.");
94   }
95 }
96
97 void Slice::execute() const
98 {
99   std::vector<int> begins;
100   begins.reserve(max_dim);
101   std::vector<int> sizes;
102   sizes.reserve(max_dim);
103   if (begin()->element_type() == DataType::S32)
104   {
105     getBeginAndSizeVectors<int32_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
106   }
107   else if (begin()->element_type() == DataType::S64)
108   {
109     getBeginAndSizeVectors<int64_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
110   }
111   else
112   {
113     throw std::runtime_error("Unsupported begin type.");
114   }
115   for (int i = input()->shape().num_dims(); i < max_dim; ++i)
116   {
117     begins.push_back(0);
118     sizes.push_back(1);
119   }
120
121   assert(begins.size() == 4);
122   assert(sizes.size() == 4);
123   tflite::SliceParams op_params{};
124   op_params.begin_count = 4;
125   op_params.size_count = 4;
126   for (int i = 0; i < 4; i++)
127   {
128     op_params.begin[i] = begins[3 - i];
129     op_params.size[i] = sizes[3 - i];
130   }
131   switch (input()->element_type())
132   {
133     case DataType::FLOAT32:
134       tflite::optimized_ops::Slice(op_params, getTensorShape(input()),
135                                    getTensorData<float>(input()), getTensorShape(output()),
136                                    getTensorData<float>(output()));
137       break;
138     case DataType::U8:
139       tflite::optimized_ops::Slice(op_params, getTensorShape(input()),
140                                    getTensorData<uint8_t>(input()), getTensorShape(output()),
141                                    getTensorData<uint8_t>(output()));
142       break;
143     default:
144       throw std::runtime_error("Unsupported input type.");
145   }
146 }
147
148 } // namespace kernels
149 } // namespace luci_interpreter