2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2016-2018 ARM Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
41 #include "arm_compute/core/NEON/NEElementwiseOperationFuncs.h"
44 #include "arm_compute/core/Types.h"
45 #include "arm_compute/core/NEON/NEAsymm.h"
46 #include "arm_compute/core/ITensor.h"
47 #include "arm_compute/core/Helpers.h"
48 #include "arm_compute/core/Window.h"
53 using namespace arm_compute;
54 template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
55 void elementwise_op_templ(
56 const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
57 OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
58 int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &,
59 OutputScalarType *, const bool),
60 int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *,
63 // Create input windows
64 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
65 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
67 // Clear X Dimension on execution window as we handle manually
69 win.set(Window::DimX, Window::Dimension(0, 1, 1));
71 const int window_step_x = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
72 const auto window_start_x = static_cast<int>(window.x().start());
73 const auto window_end_x = static_cast<int>(window.x().end());
74 const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
76 if (is_broadcast_across_x)
78 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
79 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
80 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
81 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
82 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
84 // Clear X Dimension on execution window as we handle manually
85 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
87 Iterator broadcast_input(broadcast_tensor, broadcast_win);
88 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
89 Iterator output(out, win);
93 [&](const Coordinates &) {
94 auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
95 const auto non_broadcast_input_ptr =
96 reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
97 const InputScalarType broadcast_value =
98 *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
101 (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr,
102 broadcast_value, output_ptr, !is_broadcast_input_2);
103 for (; x < window_end_x; ++x)
105 const auto a = *(non_broadcast_input_ptr + x);
106 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a,
107 !is_broadcast_input_2 ? a : broadcast_value);
110 broadcast_input, non_broadcast_input, output);
114 // Clear X Dimension on execution window as we handle manually
115 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
116 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
118 Iterator input1(in1, input1_win);
119 Iterator input2(in2, input2_win);
120 Iterator output(out, win);
124 [&](const Coordinates &) {
125 auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
126 const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
127 const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());
129 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr,
131 for (; x < window_end_x; ++x)
133 const auto a = *(input1_ptr + x);
134 const auto b = *(input2_ptr + x);
135 *(output_ptr + x) = (*scalar_func)(a, b);
138 input1, input2, output);
144 namespace arm_compute
147 void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
148 float (*scalar_func)(const float &, const float &),
149 int (*broadcast_func)(int, int, int, const float *, const float &, float *,
151 int (*neon_func)(int, int, int, const float *, const float *, float *))
153 elementwise_op_templ<float, float, float32x4_t>(in1, in2, out, window, scalar_func,
154 broadcast_func, neon_func);
157 void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
158 uint8_t (*scalar_func)(const uint8_t &, const uint8_t &),
159 int (*broadcast_func)(int, int, int, const uint8_t *, const uint8_t &,
160 uint8_t *, const bool),
161 int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *))
163 elementwise_op_templ<uint8_t, uint8_t, uint8x16_t>(in1, in2, out, window, scalar_func,
164 broadcast_func, neon_func);
166 } // namespace arm_compute