2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2019 Arm Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40 #ifndef __ARM_COMPUTE_NEONEHOTKERNEL_H__
41 #define __ARM_COMPUTE_NEONEHOTKERNEL_H__
42 #include "arm_compute/core/NEON/INEKernel.h"
43 #include "arm_compute/core/Types.h"
46 // Forward declarations
48 /** Kernel to perform the one-hot operation on NEON */
49 class NEOneHotKernel : public INEKernel
52 /** Default constructor. */
54 /** Prevent instances of this class from being copy-constructed (this class holds raw tensor pointers). */
55 NEOneHotKernel(const NEOneHotKernel &) = delete;
56 /** Prevent instances of this class from being copy-assigned (this class holds raw tensor pointers). */
57 NEOneHotKernel &operator=(const NEOneHotKernel &) = delete;
58 /** Allow instances of this class to be move-constructed. */
59 NEOneHotKernel(NEOneHotKernel &&) = default;
60 /** Allow instances of this class to be move-assigned. */
61 NEOneHotKernel &operator=(NEOneHotKernel &&) = default;
62 /** Default destructor */
63 ~NEOneHotKernel() = default;
64 /** Name of the kernel
// Returns the fixed identifier string of this kernel, "NEOneHotKernel".
68 const char *name() const override { return "NEOneHotKernel"; }
69 /** Initialise the kernel's inputs and outputs
71 * @param[in] indices Indices tensor. Supported tensor rank: up to 3. Must be one of the
72 * following types: U32/S32
73 * @param[in] depth The tensor for depth of the one hot dimension.
74 * Supported tensor rank: up to 3.
75 * Must be one of the following types: U32/S32
76 * @param[in] on_value On value tensor. Supported tensor rank: only 1.
77 * Data type supported: U8/S8/U16/S16/F16/U32/S32/F32
78 * @param[in] off_value Off value tensor. Supported tensor rank: only 1.
79 * Data type supported: Same as @p on_value
80 * @param[out] output Destination tensor. Data type supported: Same as @p on_value
81 * @param[in] axis (Optional) The axis to fill. Negative values wrap around.
83 * The value must be in range [-indices.rank , indices.rank)
// Declaration only; no definition appears alongside it here. `axis` is -1 when the caller
// omits it, and `output` is the only tensor taken through a non-const pointer.
85 void configure(const ITensor *indices, const ITensor *depth, const ITensor *on_value,
86 const ITensor *off_value, ITensor *output, int axis = -1);
87 /** Static function to check if given info will lead to a valid configuration of @ref NEOneHotKernel
90 * @param[in] indices Indices tensor info. Supported tensor rank: up to 3.
91 * Must be one of the following types: U32/S32
92 * @param[in] depth The tensor info for depth of the one hot dimension.
93 * Supported tensor rank: up to 3.
94 * Must be one of the following types: U32/S32
95 * @param[in] on_value On value tensor info. Supported tensor rank: only 1.
96 * Data type supported: U8/S8/U16/S16/F16/U32/S32/F32
97 * @param[in] off_value Off value tensor info. Supported tensor rank: only 1.
98 * Data type supported: Same as @p on_value
99 * @param[out] output Destination tensor info. Data type supported: Same as @p on_value
100 * @param[in] axis (Optional) The axis to fill. Negative values wrap around. Defaults to -1.
101 * The value must be in range [-indices.rank , indices.rank)
// Static check that mirrors the parameter list of configure(), but takes ITensorInfo
// descriptors (all const, including `output`) and returns its result as a Status.
// `axis` is -1 when the caller omits it.
105 static Status validate(const ITensorInfo *indices, const ITensorInfo *depth,
106 const ITensorInfo *on_value, const ITensorInfo *off_value,
107 const ITensorInfo *output, int axis = -1);
108 // Inherited methods overridden:
// Overrides an inherited virtual; receives the execution window and the calling thread's info.
// NOTE(review): presumably forwards to one of the onehot_*_axis helpers - confirm in the .cpp.
109 void run(const Window &window, const ThreadInfo &info) override;
112 /** Implementation of the onehot operation for 0 axis.
114 * For onehot on the 0 axis an element by element copy is performed.
116 * @param[in] window Region on which to execute the kernel. (Must be a region of the window
117 * returned by window())
118 * @param[in] info Info about executing thread and CPU.
// Member template parameterised on U; each instantiation has the signature described by kernel_ptr.
// NOTE(review): U is presumably the element type of the indices tensor - confirm in the .cpp.
120 template <typename U> void onehot_0_axis(const Window &window, const ThreadInfo &info);
121 /** Implementation of the onehot operation.
123 * For axis >= 1 a row-wise copy takes place.
125 * @param[in] window Region on which to execute the kernel. (Must be a region of the window
126 * returned by window())
127 * @param[in] info Info about executing thread and CPU.
// Member template parameterised on U; each instantiation has the signature described by kernel_ptr.
// NOTE(review): U is presumably the element type of the indices tensor - confirm in the .cpp.
129 template <typename U> void onehot_n_axis(const Window &window, const ThreadInfo &info);
// Pointer-to-member-function type with the same signature as the onehot_*_axis helpers.
// NOTE(review): presumably holds the implementation selected for the configured axis - confirm in the .cpp.
130 using kernel_ptr = void (NEOneHotKernel::*)(const Window &window, const ThreadInfo &info);
// Read-only tensor pointers named after the corresponding configure() inputs. Nothing here
// shows who owns them; NOTE(review): presumably non-owning and assigned in configure() - confirm.
131 const ITensor *_indices;
132 const ITensor *_depth;
133 const ITensor *_on_value;
134 const ITensor *_off_value;
139 } // namespace arm_compute
140 #endif /* __ARM_COMPUTE_NEONEHOTKERNEL_H__ */