Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / ARMComputeEx / src / core / NEON / kernels / NEOneHotKernel.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * Copyright (c) 2019 Arm Limited.
19  *
20  * SPDX-License-Identifier: MIT
21  *
22  * Permission is hereby granted, free of charge, to any person obtaining a copy
23  * of this software and associated documentation files (the "Software"), to
24  * deal in the Software without restriction, including without limitation the
25  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26  * sell copies of the Software, and to permit persons to whom the Software is
27  * furnished to do so, subject to the following conditions:
28  *
29  * The above copyright notice and this permission notice shall be included in all
30  * copies or substantial portions of the Software.
31  *
32  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38  * SOFTWARE.
39  */
40 #include "arm_compute/core/NEON/kernels/NEOneHotKernel.h"
41 #include "arm_compute/core/CPP/Validate.h"
42 #include "arm_compute/core/Coordinates.h"
43 #include "arm_compute/core/Error.h"
44 #include "arm_compute/core/Helpers.h"
45 #include "arm_compute/core/IAccessWindow.h"
46 #include "arm_compute/core/TensorInfo.h"
47 #include "arm_compute/core/Validate.h"
48 #include "arm_compute/core/Window.h"
49 #include "arm_compute/core/utils/misc/ShapeCalculatorEx.h"
50 namespace arm_compute
51 {
52 namespace
53 {
54 /** Validate the depth
55  *
56  * Validate that depth are not negative
57  *
58  * @param[in] depth Depth tensor.
59  * @param[in] output Output tensor.
60  * @param[in] axis Axis of depth.
61  */
62 template <typename U> void validate_depth(const ITensor *depth, const ITensor *output, int axis)
63 {
64   ARM_COMPUTE_ERROR_ON(*(reinterpret_cast<U *>(depth->buffer())) < 0);
65   ARM_COMPUTE_ERROR_ON(static_cast<U>(output->info()->tensor_shape()[axis]) !=
66                        *(reinterpret_cast<U *>(depth->buffer())));
67 }
68
69 Status validate_arguments(const ITensorInfo *indices, const ITensorInfo *depth,
70                           const ITensorInfo *on_value, const ITensorInfo *off_value,
71                           const ITensorInfo *output, int axis)
72 {
73   ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(indices, depth, on_value, off_value, output);
74   const int actual_axis = wrap_around(axis, static_cast<int>(output->num_dimensions()));
75   ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 4);
76   ARM_COMPUTE_RETURN_ERROR_ON(on_value->tensor_shape().total_size() != 1);
77   ARM_COMPUTE_RETURN_ERROR_ON(0 > actual_axis ||
78                               actual_axis >= static_cast<int>(output->num_dimensions()));
79   ARM_COMPUTE_RETURN_ERROR_ON(on_value->data_type() == DataType::UNKNOWN);
80   ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(on_value, 1, DataType::U8, DataType::S8,
81                                                        DataType::U16, DataType::S16, DataType::F16,
82                                                        DataType::U32, DataType::S32, DataType::F32);
83   ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32);
84
85   ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
86   if (output->total_size() != 0)
87   {
88     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, output);
89   }
90
91   return Status{};
92 }
93
94 template <typename U, typename Enable = void> bool isOnValue(U) { return true; }
95
96 template <typename U, std::enable_if_t<std::is_integral<U>::value, int> = 0>
97 bool isOnValue(U index, U depth)
98 {
99   return index >= 0 && index < depth;
100 }
101 } // namespace
102
103 NEOneHotKernel::NEOneHotKernel()
104   : _indices{nullptr}, _depth{nullptr}, _on_value{nullptr},
105     _off_value{nullptr}, _axis{-1}, _output{nullptr}, _func{}
106 {
107 }
108
109 template <typename U>
110 void NEOneHotKernel::onehot_0_axis(const Window &window, const ThreadInfo &info)
111 {
112   ARM_COMPUTE_UNUSED(info);
113   // Validate that the depth are not negative
114   validate_depth<U>(_depth, _output, _axis);
115   Window output_window{window};
116   output_window.set(Window::DimX, Window::Dimension(0, 1, 1));
117   Iterator output_it(_output, output_window);
118   const U off_value = *reinterpret_cast<U *>(_off_value->buffer());
119   execute_window_loop(
120     output_window,
121     [&](const Coordinates &id) {
122       std::fill_n(output_it.ptr(), _output->info()->dimension(0) * _output->info()->element_size(),
123                   off_value);
124       Coordinates indices_id(id);
125       indices_id.remove(0);
126       const U new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(indices_id)));
127       if (isOnValue(new_index, *(reinterpret_cast<U *>(_depth->buffer()))))
128       {
129         Coordinates onehot_id(id);
130         onehot_id.set(0, new_index);
131         std::copy_n(_on_value->buffer(), _output->info()->element_size(),
132                     _output->ptr_to_element(onehot_id));
133       }
134     },
135     output_it);
136 }
137
138 template <typename U>
139 inline void NEOneHotKernel::onehot_n_axis(const Window &window, const ThreadInfo &info)
140 {
141   ARM_COMPUTE_UNUSED(info);
142   // Validate that the indices are not negative
143   validate_depth<U>(_depth, _output, _axis);
144   Iterator output_it(_output, window);
145   execute_window_loop(
146     window,
147     [&](const Coordinates &id) {
148       Coordinates indices_id(id);
149       indices_id.remove(_axis);
150       const U new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(indices_id)));
151       if (isOnValue(new_index, *(reinterpret_cast<U *>(_depth->buffer()))))
152       {
153         Coordinates onehot_id(id);
154         onehot_id.set(_axis, new_index);
155         std::copy_n(static_cast<U>(id[_axis]) == new_index ? _on_value->buffer()
156                                                            : _off_value->buffer(),
157                     _output->info()->element_size(), output_it.ptr());
158       }
159     },
160     output_it);
161 }
162
163 void NEOneHotKernel::configure(const ITensor *indices, const ITensor *depth,
164                                const ITensor *on_value, const ITensor *off_value, ITensor *output,
165                                int axis)
166 {
167   ARM_COMPUTE_ERROR_ON_NULLPTR(indices, depth, on_value, off_value, output);
168   ARM_COMPUTE_ERROR_ON(output->info()->total_size() == 0);
169   ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(indices->info(), depth->info(), on_value->info(),
170                                                 off_value->info(), output->info(), axis));
171   _indices = indices;
172   _depth = depth;
173   _on_value = on_value;
174   _off_value = off_value;
175   _output = output;
176   _axis = wrap_around(axis, static_cast<int>(output->info()->num_dimensions()));
177   if (0 == _axis)
178   {
179     switch (_indices->info()->data_type())
180     {
181       case DataType::U32:
182         _func = &NEOneHotKernel::onehot_0_axis<uint32_t>;
183         break;
184       case DataType::S32:
185         _func = &NEOneHotKernel::onehot_0_axis<int32_t>;
186         break;
187       default:
188         ARM_COMPUTE_ERROR("Not supported");
189         break;
190     }
191   }
192   else
193   {
194     switch (_indices->info()->data_type())
195     {
196       case DataType::U32:
197         _func = &NEOneHotKernel::onehot_n_axis<uint32_t>;
198         break;
199       case DataType::S32:
200         _func = &NEOneHotKernel::onehot_n_axis<int32_t>;
201         break;
202       default:
203         ARM_COMPUTE_ERROR("Not supported");
204         break;
205     }
206   }
207   // Create window
208   Window win = calculate_max_window(*output->info(), Steps());
209   output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
210   INEKernel::configure(win);
211 }
212
213 Status NEOneHotKernel::validate(const ITensorInfo *indices, const ITensorInfo *depth,
214                                 const ITensorInfo *on_value, const ITensorInfo *off_value,
215                                 const ITensorInfo *output, int axis)
216 {
217   ARM_COMPUTE_RETURN_ON_ERROR(
218     validate_arguments(indices, depth, on_value, off_value, output, axis));
219   return Status{};
220 }
221
222 void NEOneHotKernel::run(const Window &window, const ThreadInfo &info)
223 {
224   ARM_COMPUTE_UNUSED(info);
225   ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
226   ARM_COMPUTE_ERROR_ON(_func == nullptr);
227   (this->*_func)(window, info);
228 }
229 } // namespace arm_compute