2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2018-2020 Arm Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40 #include "arm_compute/core/CL/kernels/CLOneHotKernel.h"
41 #include "arm_compute/core/CL/ICLTensor.h"
42 #include "arm_compute/core/CL/CLKernelLibraryEx.h"
43 #include "arm_compute/core/Error.h"
44 #include "arm_compute/core/utils/misc/ShapeCalculatorEx.h"
45 #include "support/StringSupport.h"
51 inline Status validate_arguments(const ITensorInfo *indices, const ITensorInfo *on_value,
52 const ITensorInfo *output, int depth, int axis)
54 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(indices, on_value, output);
55 const uint32_t actual_axis = wrap_around(axis, static_cast<int>(output->num_dimensions()));
56 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 4);
57 ARM_COMPUTE_RETURN_ERROR_ON(on_value->tensor_shape().total_size() != 1);
58 ARM_COMPUTE_RETURN_ERROR_ON(depth <= 0);
59 ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= output->num_dimensions());
60 ARM_COMPUTE_RETURN_ERROR_ON(on_value->data_type() == DataType::UNKNOWN);
61 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(on_value, 1, DataType::U8, DataType::S8,
62 DataType::U16, DataType::S16, DataType::F16,
63 DataType::U32, DataType::S32, DataType::F32);
64 if (output->total_size() != 0)
66 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, output);
67 TensorShape output_shape = arm_compute::misc::shape_calculator::compute_onehot_shape_ex(
68 indices->tensor_shape(), static_cast<uint32_t>(depth), actual_axis);
69 ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size());
74 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *indices,
75 const ITensorInfo *on_value,
76 ITensorInfo *output, int depth, int axis)
78 ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, output, indices);
79 const uint32_t actual_axis = wrap_around(axis, static_cast<int>(output->num_dimensions()));
80 // Output auto initialization if not yet initialized
81 TensorShape output_shape = arm_compute::misc::shape_calculator::compute_onehot_shape_ex(
82 indices->tensor_shape(), static_cast<uint32_t>(depth), actual_axis);
83 auto_init_if_empty((*output), output_shape, 1, on_value->data_type());
85 Window win = calculate_max_window(*output, Steps());
86 output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
87 return std::make_pair(Status{}, win);
90 CLOneHotKernel::CLOneHotKernel()
91 : _indices(nullptr), _on_value(nullptr), _off_value(nullptr), _output(nullptr),
92 _is_off_value_memset(false)
95 void CLOneHotKernel::configure(const ICLTensor *indices, const ICLTensor *on_value,
96 const ICLTensor *off_value, ICLTensor *output, int depth, int axis)
98 _is_off_value_memset = false;
99 ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, off_value, output);
100 ARM_COMPUTE_ERROR_ON_NULLPTR(off_value->info());
101 ARM_COMPUTE_ERROR_ON(off_value->info()->tensor_shape().total_size() != 1);
102 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
103 _off_value = off_value;
104 configure_common(indices, on_value, output, depth, axis);
106 void CLOneHotKernel::configure(const ICLTensor *indices, const ICLTensor *on_value,
107 ICLTensor *output, int depth, int axis)
109 _is_off_value_memset = true;
110 ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, output);
111 configure_common(indices, on_value, output, depth, axis);
113 void CLOneHotKernel::configure_common(const ICLTensor *indices, const ICLTensor *on_value,
114 ICLTensor *output, int depth, int axis)
116 ARM_COMPUTE_ERROR_THROW_ON(
117 validate_arguments(indices->info(), on_value->info(), output->info(), depth, axis));
118 // Configure kernel window
120 validate_and_configure_window(indices->info(), on_value->info(), output->info(), depth, axis);
121 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
122 if (_is_off_value_memset)
124 // Replace window with calculated by infices info
125 win_config.second = calculate_max_window(*indices->info(), Steps());
128 _on_value = on_value;
130 const auto actual_axis = wrap_around(axis, static_cast<int>(output->info()->num_dimensions()));
132 CLBuildOptions build_opts;
133 build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(
134 data_size_from_type(on_value->info()->data_type())));
135 build_opts.add_option("-DAXIS=" + support::cpp11::to_string(actual_axis));
136 build_opts.add_option("-DDEPTH=" + support::cpp11::to_string(depth));
137 build_opts.add_option("-DOUTPUT_DIM_Z=" +
138 support::cpp11::to_string(output->info()->dimension(2)));
140 const std::string kernel_name = _is_off_value_memset ? "one_hot_only_on_value" : "one_hot";
141 _kernel = static_cast<cl::Kernel>(
142 CLKernelLibraryEx::get().create_kernel(kernel_name, build_opts.options()));
143 ICLKernel::configure_internal(win_config.second);
145 Status CLOneHotKernel::validate(const ITensorInfo *indices, const ITensorInfo *on_value,
146 const ITensorInfo *off_value, const ITensorInfo *output, int depth,
149 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(off_value);
150 ARM_COMPUTE_RETURN_ERROR_ON(off_value->tensor_shape().total_size() != 1);
151 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
152 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(indices, on_value, output, depth, axis));
153 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(indices->clone().get(),
154 on_value->clone().get(),
155 output->clone().get(), depth, axis)
159 Status CLOneHotKernel::validate(const ITensorInfo *indices, const ITensorInfo *on_value,
160 const ITensorInfo *output, int depth, int axis)
162 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(indices, on_value, output, depth, axis));
163 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(indices->clone().get(),
164 on_value->clone().get(),
165 output->clone().get(), depth, axis)
169 void CLOneHotKernel::run(const Window &window, cl::CommandQueue &queue)
171 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
172 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
173 Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
174 unsigned int idx = 0;
175 add_3D_tensor_argument(idx, _indices, window_collapsed);
176 add_1D_tensor_argument(idx, _on_value, window_collapsed);
177 if (!_is_off_value_memset)
179 add_1D_tensor_argument(idx, _off_value, window_collapsed);
181 add_4D_tensor_argument(idx, _output, window_collapsed);
182 enqueue(queue, *this, window_collapsed, lws_hint());
185 } // namespace arm_compute