2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2017 ARM Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40 #include "arm_compute/runtime/CL/functions/CLSplitVEx.h"
41 #include "support/ToolchainSupport.h"
42 #include "arm_compute/core/Error.h"
43 #include "arm_compute/core/Helpers.h"
44 #include "arm_compute/core/CL/ICLTensor.h"
45 #include "arm_compute/core/TensorInfo.h"
46 #include "arm_compute/core/Types.h"
47 #include "arm_compute/core/Validate.h"
48 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
49 #include "arm_compute/runtime/CL/CLScheduler.h"
50 #include "src/core/helpers/AutoConfiguration.h"
53 using namespace arm_compute;
57 Status validate_arguments(const ICLTensor *size_splits, const std::vector<ICLTensor *> &outputs,
58 unsigned int num_splits)
60 ARM_COMPUTE_RETURN_ERROR_ON_MSG(size_splits->info()->num_dimensions() != 1,
61 "size_splits must be a 1-D tensor.");
62 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_splits != outputs.size(),
63 "Number of output tensors does not match number of splits.");
67 Status validate_slices(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs,
70 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
71 ARM_COMPUTE_RETURN_ERROR_ON(split_dim >= input->num_dimensions());
72 ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2);
74 // Start/End coordinates
75 Coordinates start_coords;
76 Coordinates end_coords;
77 for (unsigned int d = 0; d < input->num_dimensions(); ++d)
79 end_coords.set(d, -1);
81 unsigned int axis_offset = 0;
82 // Validate output tensors
83 for (const auto &output : outputs)
85 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
87 const TensorShape output_shape = output->tensor_shape();
88 ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
90 const size_t axis_split_step = output_shape[split_dim];
92 // Output auto inizialitation if not yet initialized
93 TensorInfo tmp_output_info = *output->clone();
94 auto_init_if_empty(tmp_output_info,
95 input->clone()->set_is_resizable(true).set_tensor_shape(output_shape));
97 // Update coordinate on axis
98 start_coords.set(split_dim, axis_offset);
99 end_coords.set(split_dim, axis_offset + axis_split_step);
101 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(input, output, start_coords, end_coords));
103 axis_offset += axis_split_step;
109 void configure_slices(const ICLTensor *input, const std::vector<ICLTensor *> &outputs,
110 std::vector<CLSlice> &_slice_functions, uint32_t split_dim)
112 unsigned int axis_offset = 0;
113 // Start/End coordinates
114 Coordinates start_coords;
115 Coordinates end_coords;
116 for (unsigned int d = 0; d < input->info()->num_dimensions(); ++d)
118 end_coords.set(d, -1);
121 for (const auto &output : outputs)
123 const TensorShape output_shape = output->info()->tensor_shape();
124 auto op_size = output_shape.total_size();
130 assert(op_size != 0);
131 assert(split_dim <= output_shape.num_dimensions());
133 const size_t axis_split_step = output_shape[split_dim];
135 // Output auto inizialitation if not yet initialized
136 TensorInfo tmp_output_info = *output->info()->clone();
139 input->info()->clone()->set_is_resizable(true).set_tensor_shape(output_shape));
141 // Update coordinate on axis
142 start_coords.set(split_dim, axis_offset);
143 end_coords.set(split_dim, axis_offset + axis_split_step);
145 // Configure slice function
146 _slice_functions[out_iter].configure(input, output, start_coords, end_coords);
148 // Set valid region from shape
149 outputs[out_iter++]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));
150 axis_offset += axis_split_step;
156 CLSplitVEx::CLSplitVEx()
157 : _input(nullptr), _size_splits(nullptr), _outputs(), _num_splits(0), _slice_functions()
161 void CLSplitVEx::configure(const ICLTensor *input, const ICLTensor *size_splits, uint32_t split_dim,
162 const std::vector<ICLTensor *> &outputs, unsigned int num_splits)
164 ARM_COMPUTE_ERROR_ON_NULLPTR(input, size_splits);
165 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(size_splits, outputs, num_splits));
168 _size_splits = size_splits;
170 _num_splits = num_splits;
172 // Create tensor slices
173 _slice_functions.resize(_num_splits);
175 // Extract output tensor info
176 std::vector<ITensorInfo *> outputs_info;
177 for (auto &&output : _outputs)
179 ARM_COMPUTE_ERROR_ON_NULLPTR(output);
180 outputs_info.emplace_back(output->info());
184 ARM_COMPUTE_ERROR_THROW_ON(validate_slices(_input->info(), outputs_info, split_dim));
187 configure_slices(_input, _outputs, _slice_functions, split_dim);
190 void CLSplitVEx::run()
192 // execute the slices
193 for (unsigned i = 0; i < _outputs.size(); ++i)
195 _slice_functions[i].run();