2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SliceLayer.h"
19 #include "OperationUtils.h"
21 #include <cker/operation/Slice.h>
32 SliceLayer::SliceLayer() : _input(nullptr), _begin(nullptr), _size(nullptr), _output(nullptr)
38 void SliceLayer::GetBeginAndSizeVectors(int dimensions, const IPortableTensor *begin,
39 const IPortableTensor *size, std::vector<int> *begins,
40 std::vector<int> *sizes)
42 for (int idx = dimensions - 1; idx >= 0; --idx)
44 begins->push_back(reinterpret_cast<T *>(begin->buffer())[idx]);
45 sizes->push_back(reinterpret_cast<T *>(size->buffer())[idx]);
49 template <typename T> void SliceLayer::sliceImpl()
51 const int kMaxDim = nnfw::cker::Shape::kMaxSmallSize;
53 std::vector<int> begins;
54 std::vector<int> sizes;
55 begins.reserve(kMaxDim);
56 sizes.reserve(kMaxDim);
58 GetBeginAndSizeVectors<int32_t>(_input->num_dimensions(), _begin, _size, &begins, &sizes);
60 // begins : 0-based, sizes : 1-based
61 for (int i = _input->num_dimensions(); i < kMaxDim; ++i)
67 nnfw::cker::SliceParams op_params;
68 op_params.begin_count = 4;
69 op_params.size_count = 4;
70 for (int i = 0; i < 4; ++i)
72 op_params.begin[i] = begins[3 - i];
73 op_params.size[i] = sizes[3 - i];
76 nnfw::cker::Slice(op_params, getExtendedTensorShape(_input),
77 reinterpret_cast<const T *>(_input->buffer()),
78 reinterpret_cast<T *>(_output->buffer()));
81 void SliceLayer::configure(const IPortableTensor *input, const IPortableTensor *begin,
82 const IPortableTensor *size, IPortableTensor *output)
90 void SliceLayer::run()
92 if (_input->data_type() == OperandType::FLOAT32)
96 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
102 throw std::runtime_error{"Slice: unsupported data type"};
108 } // namespace backend