2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2017 ARM Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40 #include "arm_compute/runtime/CL/functions/CLSplitVEx.h"
41 #include "support/ToolchainSupport.h"
42 #include "arm_compute/core/Error.h"
43 #include "arm_compute/core/Helpers.h"
44 #include "arm_compute/core/CL/ICLTensor.h"
45 #include "arm_compute/core/TensorInfo.h"
46 #include "arm_compute/core/Types.h"
47 #include "arm_compute/core/Validate.h"
48 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
49 #include "arm_compute/runtime/CL/CLScheduler.h"
52 using namespace arm_compute;
// Checks the static (non-shape) arguments of CLSplitVEx::configure():
//  - size_splits must be a 1-D tensor;
//  - the number of output tensors must equal num_splits.
// Returns early with an error Status on the first failed check; the success return
// is on a line not shown in this view.
// NOTE(review): size_splits is dereferenced without a null check here; the visible
// caller (CLSplitVEx::configure) null-checks it beforehand. The individual entries of
// `outputs` are not checked in this function.
56 Status validate_arguments(const ICLTensor *size_splits, const std::vector<ICLTensor *> &outputs,
57 unsigned int num_splits)
59 ARM_COMPUTE_RETURN_ERROR_ON_MSG(size_splits->info()->num_dimensions() != 1,
60 "size_splits must be a 1-D tensor.");
61 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_splits != outputs.size(),
62 "Number of output tensors does not match number of splits.");
// Validates, without configuring anything, that `input` can be sliced along
// `split_dim` into `outputs`. The extent of each output on split_dim is taken from
// that output's own shape. Consecutive outputs are laid out back to back on that
// axis, starting at offset 0. Each resulting window is checked with CLSlice::validate().
// (split_dim is declared on a parameter line not shown in this view.)
// NOTE(review): the visible code does not check that the per-output extents sum to
// input->dimension(split_dim); presumably out-of-range windows are rejected by
// CLSlice::validate() — confirm.
66 Status validate_slices(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs,
69 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
70 ARM_COMPUTE_RETURN_ERROR_ON(split_dim >= input->num_dimensions());
// A split must produce at least two outputs.
71 ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2);
73 // Start/End coordinates
// end_coords is set to -1 on every input dimension (presumably CLSlice's
// "up to the end of this dimension" marker — confirm); only the split_dim entries
// of start_coords/end_coords are overwritten per output below.
74 Coordinates start_coords;
75 Coordinates end_coords;
76 for (unsigned int d = 0; d < input->num_dimensions(); ++d)
78 end_coords.set(d, -1);
// Running offset along split_dim at which the next output's window begins.
80 unsigned int axis_offset = 0;
81 // Validate output tensors
82 for (const auto &output : outputs)
84 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
86 const TensorShape output_shape = output->tensor_shape();
87 ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
// Size of this output along the split axis = width of its window in `input`.
89 const size_t axis_split_step = output_shape[split_dim];
91 // Output auto initialization if not yet initialized
// NOTE(review): tmp_output_info is never read after this call. CLSlice::validate()
// below is given `output`, not &tmp_output_info. This auto-initialization therefore
// does not influence the result; either pass the temporary or drop it.
92 TensorInfo tmp_output_info = *output->clone();
93 auto_init_if_empty(tmp_output_info,
94 input->clone()->set_is_resizable(true).set_tensor_shape(output_shape));
96 // Update coordinate on axis
// Window on split_dim: [axis_offset, axis_offset + axis_split_step), assuming CLSlice
// treats the end coordinate as exclusive — confirm.
97 start_coords.set(split_dim, axis_offset);
98 end_coords.set(split_dim, axis_offset + axis_split_step);
100 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(input, output, start_coords, end_coords));
102 axis_offset += axis_split_step;
// Configures one CLSlice per output. Each output receives the window of `input` along
// `split_dim` that starts where the previous output's window ended; the first window
// starts at 0. The window's extent is the output's own size on split_dim.
// Also sets each output's valid region to its full tensor shape.
// Expects `_slice_functions` to already hold at least outputs.size() entries. The
// visible caller resizes it to num_splits, which validate_arguments() ties to
// outputs.size().
// NOTE(review): `out_iter` is declared on a line not shown in this view. It indexes
// _slice_functions and outputs and is incremented once, at the valid-region update
// below. The coordinate walk duplicates the one in validate_slices(); the two must
// be kept in sync.
108 void configure_slices(const ICLTensor *input, const std::vector<ICLTensor *> &outputs,
109 std::vector<CLSlice> &_slice_functions, uint32_t split_dim)
// Running offset along split_dim at which the next output's window begins.
111 unsigned int axis_offset = 0;
112 // Start/End coordinates
// end_coords is set to -1 on every input dimension (presumably CLSlice's
// "up to the end of this dimension" marker — confirm); only split_dim is
// overwritten per output.
113 Coordinates start_coords;
114 Coordinates end_coords;
115 for (unsigned int d = 0; d < input->info()->num_dimensions(); ++d)
117 end_coords.set(d, -1);
120 for (const auto &output : outputs)
122 const TensorShape output_shape = output->info()->tensor_shape();
123 auto op_size = output_shape.total_size();
// Debug-only sanity checks (assert is compiled out under NDEBUG).
// NOTE(review): the second check uses `<=` rather than `<`. This is presumably
// intended to tolerate output shapes that report fewer dimensions than the
// input — confirm.
129 assert(op_size != 0);
130 assert(split_dim <= output_shape.num_dimensions());
// Size of this output along the split axis = width of its window in `input`.
132 const size_t axis_split_step = output_shape[split_dim];
134 // Output auto initialization if not yet initialized
// NOTE(review): the call that the continuation line below belongs to begins on a
// line not shown in this view. It is presumably an auto_init_if_empty(...) call,
// as in validate_slices — confirm which TensorInfo it targets.
135 TensorInfo tmp_output_info = *output->info()->clone();
138 input->info()->clone()->set_is_resizable(true).set_tensor_shape(output_shape));
140 // Update coordinate on axis
141 start_coords.set(split_dim, axis_offset);
142 end_coords.set(split_dim, axis_offset + axis_split_step);
144 // Configure slice function
145 _slice_functions[out_iter].configure(input, output, start_coords, end_coords);
147 // Set valid region from shape
// outputs[out_iter] is the same tensor as `output` here; the whole shape is valid.
148 outputs[out_iter++]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));
149 axis_offset += axis_split_step;
// Default constructor. No tensors are attached, the split count is zero and there
// are no slice functions until configure() is called.
155 CLSplitVEx::CLSplitVEx()
156 : _input(nullptr), _size_splits(nullptr), _outputs(), _num_splits(0), _slice_functions()
// Configures the split of `input` along `split_dim` into `outputs`.
//   input       Source tensor; must not be null.
//   size_splits 1-D tensor; must not be null. Only its rank is checked and the
//               pointer is stored. The visible code derives every slice extent from
//               the output shapes, not from this tensor's contents.
//   split_dim   Axis to split along; must be < input rank (checked in validate_slices).
//   outputs     Destination tensors. The count must equal num_splits and be at least 2;
//               none may be null or zero-sized.
//   num_splits  Number of splits; also the number of CLSlice functions created.
// Argument errors are reported through ARM_COMPUTE_ERROR_ON_NULLPTR /
// ARM_COMPUTE_ERROR_THROW_ON.
// NOTE(review): _input and _outputs are read below, but their assignments are on
// lines not shown in this view (presumably `_input = input;` and
// `_outputs = outputs;` — confirm).
// NOTE(review): _size_splits, _num_splits and _slice_functions are updated before
// validate_slices() runs. A failing slice validation therefore leaves the object
// partially configured.
160 void CLSplitVEx::configure(const ICLTensor *input, const ICLTensor *size_splits, uint32_t split_dim,
161 const std::vector<ICLTensor *> &outputs, unsigned int num_splits)
163 ARM_COMPUTE_ERROR_ON_NULLPTR(input, size_splits);
164 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(size_splits, outputs, num_splits));
167 _size_splits = size_splits;
169 _num_splits = num_splits;
171 // Create tensor slices
// One CLSlice function per split; they are configured by configure_slices() below.
172 _slice_functions.resize(_num_splits);
174 // Extract output tensor info
// validate_slices() works on ITensorInfo, so collect each output's info first.
175 std::vector<ITensorInfo *> outputs_info;
176 for (auto &output : _outputs)
178 ARM_COMPUTE_ERROR_ON_NULLPTR(output);
179 outputs_info.emplace_back(output->info());
183 ARM_COMPUTE_ERROR_THROW_ON(validate_slices(_input->info(), outputs_info, split_dim));
186 configure_slices(_input, _outputs, _slice_functions, split_dim);
// Runs the configured slice functions in output order.
// The loop bound is _outputs.size(). validate_arguments() requires this to equal
// num_splits, which is the size configure() gives _slice_functions.
189 void CLSplitVEx::run()
191 // execute the slices
192 for (unsigned i = 0; i < _outputs.size(); ++i)
194 _slice_functions[i].run();