2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/StridedSlice.h"
20 #include "kernels/Utils.h"
22 #include <tensorflow/lite/kernels/internal/reference/strided_slice.h>
26 namespace luci_interpreter
32 StridedSlice::StridedSlice(const Tensor *input, const Tensor *begin, const Tensor *end,
33 const Tensor *strides, Tensor *output, const StridedSliceParams ¶ms)
34 : KernelWithParams<StridedSliceParams>({input, begin, end, strides}, {output}, params)
38 void StridedSlice::configure()
40 assert(begin()->shape().num_dims() == 1);
41 assert(end()->shape().num_dims() == 1);
42 assert(strides()->shape().num_dims() == 1);
43 assert(input()->element_type() == output()->element_type());
44 assert(begin()->element_type() == DataType::S32);
45 assert(end()->element_type() == DataType::S32);
46 assert(strides()->element_type() == DataType::S32);
47 assert(input()->shape().num_dims() <= 4);
48 if (params().ellipsis_mask != 0)
50 throw std::runtime_error("ellipsis_mask is not implemented yet.");
52 if (params().new_axis_mask != 0)
54 throw std::runtime_error("new_axis_mask is not implemented yet.");
56 if (input()->element_type() == DataType::U8)
58 assert(input()->scale() == output()->scale());
59 assert(input()->zero_point() == output()->zero_point());
61 tflite::StridedSliceParams op_params{};
62 op_params.start_indices_count = input()->shape().num_dims();
63 op_params.stop_indices_count = input()->shape().num_dims();
64 op_params.strides_count = input()->shape().num_dims();
66 for (int i = 0; i < input()->shape().num_dims(); i++)
68 op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
69 op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
70 op_params.strides[i] = getTensorData<int32_t>(strides())[i];
72 op_params.begin_mask = params().begin_mask;
73 op_params.ellipsis_mask = 0;
74 op_params.end_mask = params().end_mask;
75 op_params.new_axis_mask = 0;
76 op_params.shrink_axis_mask = params().shrink_axis_mask;
77 std::vector<int32_t> output_shape_vector;
78 for (int i = 0; i < input()->shape().num_dims(); i++)
80 int idx = input()->shape().num_dims() - i - 1;
81 int32_t stride = getTensorData<int32_t>(strides())[idx];
83 int32_t begin = ::tflite::strided_slice::StartForAxis(op_params, getTensorShape(input()), idx);
85 ::tflite::strided_slice::StopForAxis(op_params, getTensorShape(input()), idx, begin);
87 const bool shrink_axis = params().shrink_axis_mask & (1 << idx);
93 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
94 dim_shape = dim_shape < 0 ? 0 : dim_shape;
97 output_shape_vector.push_back(dim_shape);
100 Shape output_shape = Shape(output_shape_vector.size());
101 for (size_t i = 0; i < output_shape_vector.size(); i++)
103 output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1];
105 output()->resize(output_shape);
108 void StridedSlice::execute() const
110 tflite::StridedSliceParams op_params{};
111 op_params.start_indices_count = input()->shape().num_dims();
112 op_params.stop_indices_count = input()->shape().num_dims();
113 op_params.strides_count = input()->shape().num_dims();
115 for (int i = 0; i < input()->shape().num_dims(); i++)
117 op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
118 op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
119 op_params.strides[i] = getTensorData<int32_t>(strides())[i];
121 op_params.begin_mask = params().begin_mask;
122 op_params.ellipsis_mask = 0;
123 op_params.end_mask = params().end_mask;
124 op_params.new_axis_mask = 0;
125 op_params.shrink_axis_mask = params().shrink_axis_mask;
127 switch (input()->element_type())
129 case DataType::FLOAT32:
130 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
131 getTensorData<float>(input()), getTensorShape(output()),
132 getTensorData<float>(output()));
135 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
136 getTensorData<uint8_t>(input()), getTensorShape(output()),
137 getTensorData<uint8_t>(output()));
140 throw std::runtime_error("Unsupported type.");
144 } // namespace kernels
145 } // namespace luci_interpreter