2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2019 Arm Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40 #include "arm_compute/core/NEON/kernels/NEOneHotKernel.h"
41 #include "arm_compute/core/CPP/Validate.h"
42 #include "arm_compute/core/Coordinates.h"
43 #include "arm_compute/core/Error.h"
44 #include "arm_compute/core/Helpers.h"
45 #include "arm_compute/core/IAccessWindow.h"
46 #include "arm_compute/core/TensorInfo.h"
47 #include "arm_compute/core/Validate.h"
48 #include "arm_compute/core/Window.h"
49 #include "arm_compute/core/utils/misc/ShapeCalculatorEx.h"
54 /** Validate the depth
56 * Validate that depth are not negative
58 * @param[in] depth Depth tensor.
59 * @param[in] output Output tensor.
60 * @param[in] axis Axis of depth.
62 template <typename U> void validate_depth(const ITensor *depth, const ITensor *output, int axis)
64 ARM_COMPUTE_ERROR_ON(*(reinterpret_cast<U *>(depth->buffer())) < 0);
65 ARM_COMPUTE_ERROR_ON(static_cast<U>(output->info()->tensor_shape()[axis]) !=
66 *(reinterpret_cast<U *>(depth->buffer())));
69 Status validate_arguments(const ITensorInfo *indices, const ITensorInfo *depth,
70 const ITensorInfo *on_value, const ITensorInfo *off_value,
71 const ITensorInfo *output, int axis)
73 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(indices, depth, on_value, off_value, output);
74 const int actual_axis = wrap_around(axis, static_cast<int>(output->num_dimensions()));
75 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 4);
76 ARM_COMPUTE_RETURN_ERROR_ON(on_value->tensor_shape().total_size() != 1);
77 ARM_COMPUTE_RETURN_ERROR_ON(0 > actual_axis ||
78 actual_axis >= static_cast<int>(output->num_dimensions()));
79 ARM_COMPUTE_RETURN_ERROR_ON(on_value->data_type() == DataType::UNKNOWN);
80 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(on_value, 1, DataType::U8, DataType::S8,
81 DataType::U16, DataType::S16, DataType::F16,
82 DataType::U32, DataType::S32, DataType::F32);
83 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32);
85 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
86 if (output->total_size() != 0)
88 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, output);
/** Single-argument fallback: any value is reported as "on". */
template <typename U, typename Enable = void> bool isOnValue(U)
{
  return true;
}

/** Whether @p index addresses a valid one-hot position, i.e. lies in [0, @p depth).
 *
 * Only enabled for integral types.
 */
template <typename U, std::enable_if_t<std::is_integral<U>::value, int> = 0>
bool isOnValue(U index, U depth)
{
  if (index >= 0)
  {
    return depth > index;
  }
  return false;
}
103 NEOneHotKernel::NEOneHotKernel()
104 : _indices{nullptr}, _depth{nullptr}, _on_value{nullptr},
105 _off_value{nullptr}, _axis{-1}, _output{nullptr}, _func{}
109 template <typename U>
110 void NEOneHotKernel::onehot_0_axis(const Window &window, const ThreadInfo &info)
112 ARM_COMPUTE_UNUSED(info);
113 // Validate that the depth are not negative
114 validate_depth<U>(_depth, _output, _axis);
115 Window output_window{window};
116 output_window.set(Window::DimX, Window::Dimension(0, 1, 1));
117 Iterator output_it(_output, output_window);
118 const U off_value = *reinterpret_cast<U *>(_off_value->buffer());
121 [&](const Coordinates &id) {
122 std::fill_n(output_it.ptr(), _output->info()->dimension(0) * _output->info()->element_size(),
124 Coordinates indices_id(id);
125 indices_id.remove(0);
126 const U new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(indices_id)));
127 if (isOnValue(new_index, *(reinterpret_cast<U *>(_depth->buffer()))))
129 Coordinates onehot_id(id);
130 onehot_id.set(0, new_index);
131 std::copy_n(_on_value->buffer(), _output->info()->element_size(),
132 _output->ptr_to_element(onehot_id));
138 template <typename U>
139 inline void NEOneHotKernel::onehot_n_axis(const Window &window, const ThreadInfo &info)
141 ARM_COMPUTE_UNUSED(info);
142 // Validate that the indices are not negative
143 validate_depth<U>(_depth, _output, _axis);
144 Iterator output_it(_output, window);
147 [&](const Coordinates &id) {
148 Coordinates indices_id(id);
149 indices_id.remove(_axis);
150 const U new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(indices_id)));
151 if (isOnValue(new_index, *(reinterpret_cast<U *>(_depth->buffer()))))
153 Coordinates onehot_id(id);
154 onehot_id.set(_axis, new_index);
155 std::copy_n(static_cast<U>(id[_axis]) == new_index ? _on_value->buffer()
156 : _off_value->buffer(),
157 _output->info()->element_size(), output_it.ptr());
163 void NEOneHotKernel::configure(const ITensor *indices, const ITensor *depth,
164 const ITensor *on_value, const ITensor *off_value, ITensor *output,
167 ARM_COMPUTE_ERROR_ON_NULLPTR(indices, depth, on_value, off_value, output);
168 ARM_COMPUTE_ERROR_ON(output->info()->total_size() == 0);
169 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(indices->info(), depth->info(), on_value->info(),
170 off_value->info(), output->info(), axis));
173 _on_value = on_value;
174 _off_value = off_value;
176 _axis = wrap_around(axis, static_cast<int>(output->info()->num_dimensions()));
179 switch (_indices->info()->data_type())
182 _func = &NEOneHotKernel::onehot_0_axis<uint32_t>;
185 _func = &NEOneHotKernel::onehot_0_axis<int32_t>;
188 ARM_COMPUTE_ERROR("Not supported");
194 switch (_indices->info()->data_type())
197 _func = &NEOneHotKernel::onehot_n_axis<uint32_t>;
200 _func = &NEOneHotKernel::onehot_n_axis<int32_t>;
203 ARM_COMPUTE_ERROR("Not supported");
208 Window win = calculate_max_window(*output->info(), Steps());
209 output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
210 INEKernel::configure(win);
213 Status NEOneHotKernel::validate(const ITensorInfo *indices, const ITensorInfo *depth,
214 const ITensorInfo *on_value, const ITensorInfo *off_value,
215 const ITensorInfo *output, int axis)
217 ARM_COMPUTE_RETURN_ON_ERROR(
218 validate_arguments(indices, depth, on_value, off_value, output, axis));
222 void NEOneHotKernel::run(const Window &window, const ThreadInfo &info)
224 ARM_COMPUTE_UNUSED(info);
225 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
226 ARM_COMPUTE_ERROR_ON(_func == nullptr);
227 (this->*_func)(window, info);
229 } // namespace arm_compute