2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SliceLayer.h"
19 #include "OperationUtils.h"
21 #include <cker/operation/Slice.h>
32 SliceLayer::SliceLayer() : _input(nullptr), _begin(nullptr), _size(nullptr), _output(nullptr)
38 void SliceLayer::GetBeginAndSizeVectors(int dimensions, const IPortableTensor *begin,
39 const IPortableTensor *size, std::vector<int> *begins,
40 std::vector<int> *sizes)
42 for (int idx = dimensions - 1; idx >= 0; --idx)
44 begins->push_back(getBuffer<T>(begin)[idx]);
45 sizes->push_back(getBuffer<T>(size)[idx]);
49 template <typename T> void SliceLayer::sliceImpl()
51 const int kMaxDim = nnfw::cker::Shape::kMaxSmallSize;
53 std::vector<int> begins;
54 std::vector<int> sizes;
55 begins.reserve(kMaxDim);
56 sizes.reserve(kMaxDim);
58 if (_begin->data_type() == OperandType::INT32)
60 GetBeginAndSizeVectors<int32_t>(_input->getShape().rank(), _begin, _size, &begins, &sizes);
62 else if (_begin->data_type() == OperandType::INT64)
64 GetBeginAndSizeVectors<int64_t>(_input->getShape().rank(), _begin, _size, &begins, &sizes);
68 throw std::runtime_error{"Slice: unsupported begin and/or size data type"};
71 // begins : 0-based, sizes : 1-based
72 for (int i = _input->getShape().rank(); i < kMaxDim; ++i)
78 nnfw::cker::SliceParams op_params;
79 op_params.begin_count = 4;
80 op_params.size_count = 4;
81 for (int i = 0; i < 4; ++i)
83 op_params.begin[i] = begins[3 - i];
84 op_params.size[i] = sizes[3 - i];
87 nnfw::cker::Slice(op_params, getExtendedTensorShape(_input), getBuffer<T>(_input),
88 getBuffer<T>(_output));
91 void SliceLayer::configure(const IPortableTensor *input, const IPortableTensor *begin,
92 const IPortableTensor *size, IPortableTensor *output)
100 void SliceLayer::run()
102 if (_input->data_type() == OperandType::FLOAT32)
106 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
108 sliceImpl<uint8_t>();
112 throw std::runtime_error{"Slice: unsupported data type"};
118 } // namespace backend