2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Slice.h"
25 namespace luci_interpreter
// Highest input rank this kernel accepts. configure() asserts that the input rank does
// not exceed it, and execute() pads the begin/size vectors out to exactly this many entries.
constexpr int max_dim = 4;
32 Slice::Slice(const Tensor *input, const Tensor *begin, const Tensor *size, Tensor *output)
33 : Kernel({input, begin, size}, {output})
38 Shape calculateOutputShape(const Tensor *input, const Tensor *begin, const Tensor *size)
40 Shape output_shape = Shape(input->shape().num_dims());
41 for (int idx = 0; idx < input->shape().num_dims(); idx++)
43 T size_value = getTensorData<T>(size)[idx];
48 assert(false && "Invalid size.");
50 size_value = input->shape().dim(idx) - getTensorData<T>(begin)[idx];
54 if (input->shape().dim(idx) < getTensorData<T>(begin)[idx] + size_value)
56 assert(false && "Invalid begin and size.");
59 output_shape.dim(idx) = static_cast<int>(size_value);
65 void getBeginAndSizeVectors(int dimensions, const Tensor *begin, const Tensor *size,
66 std::vector<int> *begins, std::vector<int> *sizes)
68 for (int idx = dimensions - 1; idx >= 0; --idx)
70 begins->push_back(getTensorData<T>(begin)[idx]);
71 sizes->push_back(getTensorData<T>(size)[idx]);
75 void Slice::configure()
77 assert(input()->element_type() == output()->element_type());
78 assert(begin()->element_type() == DataType::S32 || begin()->element_type() == DataType::S64);
79 assert(size()->element_type() == DataType::S32 || size()->element_type() == DataType::S64);
80 assert(begin()->shape().num_dims() == 1);
81 assert(size()->shape().num_dims() == 1);
82 assert(input()->shape().num_dims() <= max_dim);
83 // TODO: enable it only if kernel with dynamic shapes
84 if (begin()->element_type() == DataType::S32)
86 output()->resize(calculateOutputShape<int32_t>(input(), begin(), size()));
88 else if (begin()->element_type() == DataType::S64)
90 output()->resize(calculateOutputShape<int64_t>(input(), begin(), size()));
94 assert(false && "Unsupported type.");
98 void Slice::execute() const
100 std::vector<int> begins;
101 begins.reserve(max_dim);
102 std::vector<int> sizes;
103 sizes.reserve(max_dim);
104 if (begin()->element_type() == DataType::S32)
106 getBeginAndSizeVectors<int32_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
108 else if (begin()->element_type() == DataType::S64)
110 getBeginAndSizeVectors<int64_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
114 assert(false && "Unsupported begin type.");
116 for (int i = input()->shape().num_dims(); i < max_dim; ++i)
122 assert(begins.size() == 4);
123 assert(sizes.size() == 4);
124 tflite::SliceParams op_params{};
125 op_params.begin_count = 4;
126 op_params.size_count = 4;
127 for (int i = 0; i < 4; i++)
129 op_params.begin[i] = begins[3 - i];
130 op_params.size[i] = sizes[3 - i];
132 switch (input()->element_type())
134 case DataType::FLOAT32:
135 luci_interpreter_pal::Slice(op_params, getTensorShape(input()), getTensorData<float>(input()),
136 getTensorShape(output()), getTensorData<float>(output()));
139 luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
140 getTensorData<uint8_t>(input()), getTensorShape(output()),
141 getTensorData<uint8_t>(output()));
144 luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
145 getTensorData<int8_t>(input()), getTensorShape(output()),
146 getTensorData<int8_t>(output()));
149 assert(false && "Unsupported input type.");
153 } // namespace kernels
154 } // namespace luci_interpreter