Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / ARMComputeEx / src / core / CL / kernels / CLOneHotKernel.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * Copyright (c) 2018-2020 Arm Limited.
19  *
20  * SPDX-License-Identifier: MIT
21  *
22  * Permission is hereby granted, free of charge, to any person obtaining a copy
23  * of this software and associated documentation files (the "Software"), to
24  * deal in the Software without restriction, including without limitation the
25  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26  * sell copies of the Software, and to permit persons to whom the Software is
27  * furnished to do so, subject to the following conditions:
28  *
29  * The above copyright notice and this permission notice shall be included in all
30  * copies or substantial portions of the Software.
31  *
32  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38  * SOFTWARE.
39  */
40 #include "arm_compute/core/CL/kernels/CLOneHotKernel.h"
41 #include "arm_compute/core/CL/ICLTensor.h"
42 #include "arm_compute/core/CL/CLKernelLibraryEx.h"
43 #include "arm_compute/core/Error.h"
44 #include "arm_compute/core/utils/misc/ShapeCalculatorEx.h"
45 #include "support/StringSupport.h"
46 #include <string>
47 namespace arm_compute
48 {
49 namespace
50 {
51 inline Status validate_arguments(const ITensorInfo *indices, const ITensorInfo *on_value,
52                                  const ITensorInfo *output, int depth, int axis)
53 {
54   ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(indices, on_value, output);
55   const uint32_t actual_axis = wrap_around(axis, static_cast<int>(output->num_dimensions()));
56   ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 4);
57   ARM_COMPUTE_RETURN_ERROR_ON(on_value->tensor_shape().total_size() != 1);
58   ARM_COMPUTE_RETURN_ERROR_ON(depth <= 0);
59   ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= output->num_dimensions());
60   ARM_COMPUTE_RETURN_ERROR_ON(on_value->data_type() == DataType::UNKNOWN);
61   ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(on_value, 1, DataType::U8, DataType::S8,
62                                                        DataType::U16, DataType::S16, DataType::F16,
63                                                        DataType::U32, DataType::S32, DataType::F32);
64   if (output->total_size() != 0)
65   {
66     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, output);
67     TensorShape output_shape = arm_compute::misc::shape_calculator::compute_onehot_shape_ex(
68       indices->tensor_shape(), static_cast<uint32_t>(depth), actual_axis);
69     ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size());
70   }
71   return Status{};
72 }
73
74 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *indices,
75                                                         const ITensorInfo *on_value,
76                                                         ITensorInfo *output, int depth, int axis)
77 {
78   ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, output, indices);
79   const uint32_t actual_axis = wrap_around(axis, static_cast<int>(output->num_dimensions()));
80   // Output auto initialization if not yet initialized
81   TensorShape output_shape = arm_compute::misc::shape_calculator::compute_onehot_shape_ex(
82     indices->tensor_shape(), static_cast<uint32_t>(depth), actual_axis);
83   auto_init_if_empty((*output), output_shape, 1, on_value->data_type());
84   // Create window
85   Window win = calculate_max_window(*output, Steps());
86   output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
87   return std::make_pair(Status{}, win);
88 }
89 } // namespace
90 CLOneHotKernel::CLOneHotKernel()
91   : _indices(nullptr), _on_value(nullptr), _off_value(nullptr), _output(nullptr),
92     _is_off_value_memset(false)
93 {
94 }
95 void CLOneHotKernel::configure(const ICLTensor *indices, const ICLTensor *on_value,
96                                const ICLTensor *off_value, ICLTensor *output, int depth, int axis)
97 {
98   _is_off_value_memset = false;
99   ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, off_value, output);
100   ARM_COMPUTE_ERROR_ON_NULLPTR(off_value->info());
101   ARM_COMPUTE_ERROR_ON(off_value->info()->tensor_shape().total_size() != 1);
102   ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
103   _off_value = off_value;
104   configure_common(indices, on_value, output, depth, axis);
105 }
106 void CLOneHotKernel::configure(const ICLTensor *indices, const ICLTensor *on_value,
107                                ICLTensor *output, int depth, int axis)
108 {
109   _is_off_value_memset = true;
110   ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, output);
111   configure_common(indices, on_value, output, depth, axis);
112 }
113 void CLOneHotKernel::configure_common(const ICLTensor *indices, const ICLTensor *on_value,
114                                       ICLTensor *output, int depth, int axis)
115 {
116   ARM_COMPUTE_ERROR_THROW_ON(
117     validate_arguments(indices->info(), on_value->info(), output->info(), depth, axis));
118   // Configure kernel window
119   auto win_config =
120     validate_and_configure_window(indices->info(), on_value->info(), output->info(), depth, axis);
121   ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
122   if (_is_off_value_memset)
123   {
124     // Replace window with calculated by infices info
125     win_config.second = calculate_max_window(*indices->info(), Steps());
126   }
127   _indices = indices;
128   _on_value = on_value;
129   _output = output;
130   const auto actual_axis = wrap_around(axis, static_cast<int>(output->info()->num_dimensions()));
131   // Set build options
132   CLBuildOptions build_opts;
133   build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(
134                                            data_size_from_type(on_value->info()->data_type())));
135   build_opts.add_option("-DAXIS=" + support::cpp11::to_string(actual_axis));
136   build_opts.add_option("-DDEPTH=" + support::cpp11::to_string(depth));
137   build_opts.add_option("-DOUTPUT_DIM_Z=" +
138                         support::cpp11::to_string(output->info()->dimension(2)));
139   // Create kernel
140   const std::string kernel_name = _is_off_value_memset ? "one_hot_only_on_value" : "one_hot";
141   _kernel = static_cast<cl::Kernel>(
142     CLKernelLibraryEx::get().create_kernel(kernel_name, build_opts.options()));
143   ICLKernel::configure_internal(win_config.second);
144 }
145 Status CLOneHotKernel::validate(const ITensorInfo *indices, const ITensorInfo *on_value,
146                                 const ITensorInfo *off_value, const ITensorInfo *output, int depth,
147                                 int axis)
148 {
149   ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(off_value);
150   ARM_COMPUTE_RETURN_ERROR_ON(off_value->tensor_shape().total_size() != 1);
151   ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
152   ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(indices, on_value, output, depth, axis));
153   ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(indices->clone().get(),
154                                                             on_value->clone().get(),
155                                                             output->clone().get(), depth, axis)
156                                 .first);
157   return Status{};
158 }
159 Status CLOneHotKernel::validate(const ITensorInfo *indices, const ITensorInfo *on_value,
160                                 const ITensorInfo *output, int depth, int axis)
161 {
162   ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(indices, on_value, output, depth, axis));
163   ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(indices->clone().get(),
164                                                             on_value->clone().get(),
165                                                             output->clone().get(), depth, axis)
166                                 .first);
167   return Status{};
168 }
169 void CLOneHotKernel::run(const Window &window, cl::CommandQueue &queue)
170 {
171   ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
172   ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
173   Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
174   unsigned int idx = 0;
175   add_3D_tensor_argument(idx, _indices, window_collapsed);
176   add_1D_tensor_argument(idx, _on_value, window_collapsed);
177   if (!_is_off_value_memset)
178   {
179     add_1D_tensor_argument(idx, _off_value, window_collapsed);
180   }
181   add_4D_tensor_argument(idx, _output, window_collapsed);
182   enqueue(queue, *this, window_collapsed, lws_hint());
183 }
184
185 } // namespace arm_compute