dfe5d59b08e6673d4d7535149bb6b9c7640992e8
[platform/core/ml/nnfw.git] / compute / ARMComputeEx / src / core / NEON / NEElementwiseOperationFuncs.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * Copyright (c) 2016-2018 ARM Limited.
19  *
20  * SPDX-License-Identifier: MIT
21  *
22  * Permission is hereby granted, free of charge, to any person obtaining a copy
23  * of this software and associated documentation files (the "Software"), to
24  * deal in the Software without restriction, including without limitation the
25  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26  * sell copies of the Software, and to permit persons to whom the Software is
27  * furnished to do so, subject to the following conditions:
28  *
29  * The above copyright notice and this permission notice shall be included in all
30  * copies or substantial portions of the Software.
31  *
32  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38  * SOFTWARE.
39  */
40
41 #include "arm_compute/core/NEON/NEElementwiseOperationFuncs.h"
42
43 #include <algorithm>
44 #include "arm_compute/core/Types.h"
45 #include "arm_compute/core/NEON/NEAsymm.h"
46 #include "arm_compute/core/ITensor.h"
47 #include "arm_compute/core/Helpers.h"
48 #include "arm_compute/core/Window.h"
49
50 namespace
51 {
52
53 using namespace arm_compute;
54 template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
55 void elementwise_op_templ(
56     const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
57     OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
58     int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &,
59                           OutputScalarType *, const bool),
60     int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *,
61                      OutputScalarType *))
62 {
63   // Create input windows
64   Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
65   Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
66
67   // Clear X Dimension on execution window as we handle manually
68   Window win = window;
69   win.set(Window::DimX, Window::Dimension(0, 1, 1));
70
71   const int window_step_x = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
72   const auto window_start_x = static_cast<int>(window.x().start());
73   const auto window_end_x = static_cast<int>(window.x().end());
74   const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
75
76   if (is_broadcast_across_x)
77   {
78     const bool is_broadcast_input_2 = input2_win.x().step() == 0;
79     Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
80     Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
81     const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
82     const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
83
84     // Clear X Dimension on execution window as we handle manually
85     non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
86
87     Iterator broadcast_input(broadcast_tensor, broadcast_win);
88     Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
89     Iterator output(out, win);
90
91     execute_window_loop(win,
92                         [&](const Coordinates &) {
93                           auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
94                           const auto non_broadcast_input_ptr =
95                               reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
96                           const InputScalarType broadcast_value =
97                               *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
98
99                           int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x,
100                                                     non_broadcast_input_ptr, broadcast_value,
101                                                     output_ptr, !is_broadcast_input_2);
102                           for (; x < window_end_x; ++x)
103                           {
104                             const auto a = *(non_broadcast_input_ptr + x);
105                             *(output_ptr + x) =
106                                 (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a,
107                                                !is_broadcast_input_2 ? a : broadcast_value);
108                           }
109                         },
110                         broadcast_input, non_broadcast_input, output);
111   }
112   else
113   {
114     // Clear X Dimension on execution window as we handle manually
115     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
116     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
117
118     Iterator input1(in1, input1_win);
119     Iterator input2(in2, input2_win);
120     Iterator output(out, win);
121
122     execute_window_loop(win,
123                         [&](const Coordinates &) {
124                           auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
125                           const auto input1_ptr =
126                               reinterpret_cast<const InputScalarType *>(input1.ptr());
127                           const auto input2_ptr =
128                               reinterpret_cast<const InputScalarType *>(input2.ptr());
129
130                           int x = (*neon_func)(window_start_x, window_end_x, window_step_x,
131                                                input1_ptr, input2_ptr, output_ptr);
132                           for (; x < window_end_x; ++x)
133                           {
134                             const auto a = *(input1_ptr + x);
135                             const auto b = *(input2_ptr + x);
136                             *(output_ptr + x) = (*scalar_func)(a, b);
137                           }
138                         },
139                         input1, input2, output);
140   }
141 }
142
143 } // namespace
144
145 namespace arm_compute
146 {
147
148 void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
149                     float (*scalar_func)(const float &, const float &),
150                     int (*broadcast_func)(int, int, int, const float *, const float &, float *,
151                                           const bool),
152                     int (*neon_func)(int, int, int, const float *, const float *, float *))
153 {
154   elementwise_op_templ<float, float, float32x4_t>(in1, in2, out, window, scalar_func,
155                                                   broadcast_func, neon_func);
156 }
157
158 void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
159                     uint8_t (*scalar_func)(const uint8_t &, const uint8_t &),
160                     int (*broadcast_func)(int, int, int, const uint8_t *, const uint8_t &,
161                                           uint8_t *, const bool),
162                     int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *))
163 {
164   elementwise_op_templ<uint8_t, uint8_t, uint8x16_t>(in1, in2, out, window, scalar_func,
165                                                      broadcast_func, neon_func);
166 }
167 } // namespace arm_compute