/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/Slice.h"

#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>
24 namespace luci_interpreter
29 const int max_dim = 4;
31 Slice::Slice(const Tensor *input, const Tensor *begin, const Tensor *size, Tensor *output)
32 : Kernel({input, begin, size}, {output})
37 Shape calculateOutputShape(const Tensor *input, const Tensor *begin, const Tensor *size)
39 Shape output_shape = Shape(input->shape().num_dims());
40 for (int idx = 0; idx < input->shape().num_dims(); idx++)
42 T size_value = getTensorData<T>(size)[idx];
47 throw std::runtime_error("Invalid size.");
49 size_value = input->shape().dim(idx) - getTensorData<T>(begin)[idx];
53 if (input->shape().dim(idx) < getTensorData<T>(begin)[idx] + size_value)
55 throw std::runtime_error("Invalid begin and size.");
58 output_shape.dim(idx) = static_cast<int>(size_value);
64 void getBeginAndSizeVectors(int dimensions, const Tensor *begin, const Tensor *size,
65 std::vector<int> *begins, std::vector<int> *sizes)
67 for (int idx = dimensions - 1; idx >= 0; --idx)
69 begins->push_back(getTensorData<T>(begin)[idx]);
70 sizes->push_back(getTensorData<T>(size)[idx]);
74 void Slice::configure()
76 assert(input()->element_type() == output()->element_type());
77 assert(begin()->element_type() == DataType::S32 || begin()->element_type() == DataType::S64);
78 assert(size()->element_type() == DataType::S32 || size()->element_type() == DataType::S64);
79 assert(begin()->shape().num_dims() == 1);
80 assert(size()->shape().num_dims() == 1);
81 assert(input()->shape().num_dims() <= max_dim);
83 if (begin()->element_type() == DataType::S32)
85 output()->resize(calculateOutputShape<int32_t>(input(), begin(), size()));
87 else if (begin()->element_type() == DataType::S64)
89 output()->resize(calculateOutputShape<int64_t>(input(), begin(), size()));
93 throw std::runtime_error("Unsupported type.");
97 void Slice::execute() const
99 std::vector<int> begins;
100 begins.reserve(max_dim);
101 std::vector<int> sizes;
102 sizes.reserve(max_dim);
103 if (begin()->element_type() == DataType::S32)
105 getBeginAndSizeVectors<int32_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
107 else if (begin()->element_type() == DataType::S64)
109 getBeginAndSizeVectors<int64_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
113 throw std::runtime_error("Unsupported begin type.");
115 for (int i = input()->shape().num_dims(); i < max_dim; ++i)
121 assert(begins.size() == 4);
122 assert(sizes.size() == 4);
123 tflite::SliceParams op_params{};
124 op_params.begin_count = 4;
125 op_params.size_count = 4;
126 for (int i = 0; i < 4; i++)
128 op_params.begin[i] = begins[3 - i];
129 op_params.size[i] = sizes[3 - i];
131 switch (input()->element_type())
133 case DataType::FLOAT32:
134 luci_interpreter_pal::Slice(op_params, getTensorShape(input()), getTensorData<float>(input()),
135 getTensorShape(output()), getTensorData<float>(output()));
138 luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
139 getTensorData<uint8_t>(input()), getTensorShape(output()),
140 getTensorData<uint8_t>(output()));
143 throw std::runtime_error("Unsupported input type.");
147 } // namespace kernels
148 } // namespace luci_interpreter