--- /dev/null
+/* NEON implementation of sin, cos, exp and log
+ *
+ * Inspired by Intel Approximate Math library, and based on the
+ * corresponding algorithms of the cephes math library
+ */
+
+/* Copyright (C) 2011 Julien Pommier
+ *
+ * This software is provided 'as-is', without any express or implied
+ * warranty. In no event will the authors be held liable for any damages
+ * arising from the use of this software.
+ *
+ * Permission is granted to anyone to use this software for any purpose,
+ * including commercial applications, and to alter it and redistribute it
+ * freely, subject to the following restrictions:
+ *
+ * 1. The origin of this software must not be misrepresented; you must not
+ * claim that you wrote the original software. If you use this software
+ * in a product, an acknowledgment in the product documentation would be
+ * appreciated but is not required.
+ * 2. Altered source versions must be plainly marked as such, and must not be
+ * misrepresented as being the original software.
+ * 3. This notice may not be removed or altered from any source distribution.
+ *
+ * (this is the zlib license)
+ */
+
+#include <arm_neon.h>
+
+#define c_inv_mant_mask ~0x7f800000u
+#define c_cephes_SQRTHF 0.707106781186547524
+#define c_cephes_log_p0 7.0376836292E-2
+#define c_cephes_log_p1 -1.1514610310E-1
+#define c_cephes_log_p2 1.1676998740E-1
+#define c_cephes_log_p3 -1.2420140846E-1
+#define c_cephes_log_p4 +1.4249322787E-1
+#define c_cephes_log_p5 -1.6668057665E-1
+#define c_cephes_log_p6 +2.0000714765E-1
+#define c_cephes_log_p7 -2.4999993993E-1
+#define c_cephes_log_p8 +3.3333331174E-1
+#define c_cephes_log_q1 -2.12194440e-4
+#define c_cephes_log_q2 0.693359375
+
+/* natural logarithm computed for 4 simultaneous float
+ * return NaN for x <= 0
+ */
+static inline float32x4_t log_ps(float32x4_t x)
+{
+ float32x4_t one = vdupq_n_f32(1);
+
+ x = vmaxq_f32(x, vdupq_n_f32(0)); /* force flush to zero on denormal values */
+ uint32x4_t invalid_mask = vcleq_f32(x, vdupq_n_f32(0));
+
+ int32x4_t ux = vreinterpretq_s32_f32(x);
+
+ int32x4_t emm0 = vshrq_n_s32(ux, 23);
+
+ /* keep only the fractional part */
+ ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask));
+ ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f)));
+ x = vreinterpretq_f32_s32(ux);
+
+ emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f));
+ float32x4_t e = vcvtq_f32_s32(emm0);
+
+ e = vaddq_f32(e, one);
+
+ /* part2:
+ * if( x < SQRTHF ) {
+ * e -= 1;
+ * x = x + x - 1.0;
+ * } else { x = x - 1.0; }
+ */
+ uint32x4_t mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
+ float32x4_t tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
+ x = vsubq_f32(x, one);
+ e = vsubq_f32(e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask)));
+ x = vaddq_f32(x, tmp);
+
+ float32x4_t z = vmulq_f32(x, x);
+
+ float32x4_t y = vdupq_n_f32(c_cephes_log_p0);
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
+ y = vmulq_f32(y, x);
+
+ y = vmulq_f32(y, z);
+
+ tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
+ y = vaddq_f32(y, tmp);
+
+ tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
+ y = vsubq_f32(y, tmp);
+
+ tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
+ x = vaddq_f32(x, y);
+ x = vaddq_f32(x, tmp);
+ x = vreinterpretq_f32_u32(
+ vorrq_u32(vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN
+ return x;
+}
+
+#define c_exp_hi 88.3762626647949f
+#define c_exp_lo -88.3762626647949f
+
+#define c_cephes_LOG2EF 1.44269504088896341
+#define c_cephes_exp_C1 0.693359375
+#define c_cephes_exp_C2 -2.12194440e-4
+
+#define c_cephes_exp_p0 1.9875691500E-4
+#define c_cephes_exp_p1 1.3981999507E-3
+#define c_cephes_exp_p2 8.3334519073E-3
+#define c_cephes_exp_p3 4.1665795894E-2
+#define c_cephes_exp_p4 1.6666665459E-1
+#define c_cephes_exp_p5 5.0000001201E-1
+
+/* exp() computed for 4 float at once */
+static inline float32x4_t exp_ps(float32x4_t x)
+{
+ float32x4_t tmp, fx;
+
+ float32x4_t one = vdupq_n_f32(1);
+ x = vminq_f32(x, vdupq_n_f32(c_exp_hi));
+ x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));
+
+ /* express exp(x) as exp(g + n*log(2)) */
+ fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
+
+ /* perform a floorf */
+ tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
+
+ /* if greater, substract 1 */
+ uint32x4_t mask = vcgtq_f32(tmp, fx);
+ mask = vandq_u32(mask, vreinterpretq_u32_f32(one));
+
+ fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
+
+ tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1));
+ float32x4_t z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2));
+ x = vsubq_f32(x, tmp);
+ x = vsubq_f32(x, z);
+
+ static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2,
+ c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5};
+ float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0);
+ float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1);
+ float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2);
+ float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3);
+ float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4);
+ float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5);
+
+ y = vmulq_f32(y, x);
+ z = vmulq_f32(x, x);
+
+ y = vaddq_f32(y, c1);
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, c2);
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, c3);
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, c4);
+ y = vmulq_f32(y, x);
+ y = vaddq_f32(y, c5);
+
+ y = vmulq_f32(y, z);
+ y = vaddq_f32(y, x);
+ y = vaddq_f32(y, one);
+
+ /* build 2^n */
+ int32x4_t mm;
+ mm = vcvtq_s32_f32(fx);
+ mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
+ mm = vshlq_n_s32(mm, 23);
+ float32x4_t pow2n = vreinterpretq_f32_s32(mm);
+
+ y = vmulq_f32(y, pow2n);
+ return y;
+}
+
+#define c_minus_cephes_DP1 -0.78515625
+#define c_minus_cephes_DP2 -2.4187564849853515625e-4
+#define c_minus_cephes_DP3 -3.77489497744594108e-8
+#define c_sincof_p0 -1.9515295891E-4
+#define c_sincof_p1 8.3321608736E-3
+#define c_sincof_p2 -1.6666654611E-1
+#define c_coscof_p0 2.443315711809948E-005
+#define c_coscof_p1 -1.388731625493765E-003
+#define c_coscof_p2 4.166664568298827E-002
+#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI
+
+/* evaluation of 4 sines & cosines at once.
+ *
+ * The code is the exact rewriting of the cephes sinf function.
+ * Precision is excellent as long as x < 8192 (I did not bother to
+ * take into account the special handling they have for greater values
+ * -- it does not return garbage for arguments over 8192, though, but
+ * the extra precision is missing).
+ *
+ * Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
+ * surprising but correct result.
+ *
+ * Note also that when you compute sin(x), cos(x) is available at
+ * almost no extra price so both sin_ps and cos_ps make use of
+ * sincos_ps..
+ */
+static inline void sincos_ps(float32x4_t x, float32x4_t *ysin, float32x4_t *ycos)
+{
+ // any x
+ float32x4_t xmm1, xmm2, xmm3, y;
+
+ uint32x4_t emm2;
+
+ uint32x4_t sign_mask_sin, sign_mask_cos;
+ sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0));
+ x = vabsq_f32(x);
+
+ /* scale by 4/Pi */
+ y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI));
+
+ /* store the integer part of y in mm0 */
+ emm2 = vcvtq_u32_f32(y);
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ emm2 = vaddq_u32(emm2, vdupq_n_u32(1));
+ emm2 = vandq_u32(emm2, vdupq_n_u32(~1));
+ y = vcvtq_f32_u32(emm2);
+
+ /* get the polynom selection mask
+ * there is one polynom for 0 <= x <= Pi/4
+ * and another one for Pi/4<x<=Pi/2
+ *
+ * Both branches will be computed.
+ */
+ uint32x4_t poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));
+
+ /* The magic pass: "Extended precision modular arithmetic"
+ * x = ((x - y * DP1) - y * DP2) - y * DP3; */
+ xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
+ xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
+ xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
+ x = vaddq_f32(x, xmm1);
+ x = vaddq_f32(x, xmm2);
+ x = vaddq_f32(x, xmm3);
+
+ sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
+ sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
+
+ /* Evaluate the first polynom (0 <= x <= Pi/4) in y1,
+ * and the second polynom (Pi/4 <= x <= 0) in y2 */
+ float32x4_t z = vmulq_f32(x, x);
+ float32x4_t y1, y2;
+
+ y1 = vmulq_n_f32(z, c_coscof_p0);
+ y2 = vmulq_n_f32(z, c_sincof_p0);
+ y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
+ y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
+ y1 = vmulq_f32(y1, z);
+ y2 = vmulq_f32(y2, z);
+ y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
+ y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
+ y1 = vmulq_f32(y1, z);
+ y2 = vmulq_f32(y2, z);
+ y1 = vmulq_f32(y1, z);
+ y2 = vmulq_f32(y2, x);
+ y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
+ y2 = vaddq_f32(y2, x);
+ y1 = vaddq_f32(y1, vdupq_n_f32(1));
+
+ /* select the correct result from the two polynoms */
+ float32x4_t ys = vbslq_f32(poly_mask, y1, y2);
+ float32x4_t yc = vbslq_f32(poly_mask, y2, y1);
+ *ysin = vbslq_f32(sign_mask_sin, vnegq_f32(ys), ys);
+ *ycos = vbslq_f32(sign_mask_cos, yc, vnegq_f32(yc));
+}
+
+static inline float32x4_t sin_ps(float32x4_t x)
+{
+ float32x4_t ysin, ycos;
+ sincos_ps(x, &ysin, &ycos);
+ return ysin;
+}
+
+static inline float32x4_t cos_ps(float32x4_t x)
+{
+ float32x4_t ysin, ycos;
+ sincos_ps(x, &ysin, &ycos);
+ return ycos;
+}
+
+static inline float32x4_t div_ps(float32x4_t a, float32x4_t b)
+{
+ float32x4_t reciprocal = vrecpeq_f32(b);
+ reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
+ // reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
+ return vmulq_f32(a, reciprocal);
+}
+
+static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b)
+{
+ // pow(x, m) = exp(m * log(x))
+ return exp_ps(vmulq_f32(b, log_ps(a)));
+}
--- /dev/null
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "ncnn/layer/binaryop.h"
+#include <math.h>
+#include <algorithm>
+#include <functional>
+#include <sys/time.h>
+
+#if __ARM_NEON
+#include <arm_neon.h>
+#include "arm/neon_mathfun.h"
+#endif // __ARM_NEON
+
+namespace nnfw
+{
+namespace ncnn
+{
+
+template <typename Op> static int binary_op(const Mat &a, const Mat &b, Mat &c)
+{
+ Op op;
+
+ int w = a.w;
+ int h = a.h;
+ int channels = a.c;
+ int size = w * h;
+
+ int w1 = b.w;
+ int h1 = b.h;
+ int channels1 = b.c;
+ int size1 = w1 * h1;
+
+ if (a.dims == 3)
+ {
+ c.create(w, h, channels);
+ if (c.empty())
+ return -100;
+
+ if (b.dims == 3)
+ {
+ if (b.w == 1 && b.h == 1)
+ {
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = a.channel(q);
+ const float *ptr1 = b.channel(q);
+ float *outptr = c.channel(q);
+
+ float tt = *ptr1;
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = op(ptr[i], tt);
+ }
+ }
+
+ return 0;
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = a.channel(q);
+ const float *ptr1 = b.channel(q);
+ float *outptr = c.channel(q);
+
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = op(ptr[i], ptr1[i]);
+ }
+ }
+
+ return 0;
+ }
+
+ if (b.dims == 2)
+ {
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = a.channel(q);
+ const float *ptr1 = (const float *)b + h * q;
+ float *outptr = c.channel(q);
+
+ for (int y = 0; y < h; y++)
+ {
+ const float b0 = ptr1[y];
+ for (int x = 0; x < w; x++)
+ {
+ outptr[x] = op(ptr[x], b0);
+ }
+
+ ptr += w;
+ outptr += w;
+ }
+ }
+
+ return 0;
+ }
+
+ if (b.dims == 1)
+ {
+ if (b.w == 1)
+ {
+ const float b0 = b[0];
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = a.channel(q);
+ float *outptr = c.channel(q);
+
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = op(ptr[i], b0);
+ }
+ }
+
+ return 0;
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = a.channel(q);
+ const float b0 = b[q];
+ float *outptr = c.channel(q);
+
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = op(ptr[i], b0);
+ }
+ }
+
+ return 0;
+ }
+ }
+ else if (a.dims == 2)
+ {
+ if (b.dims == 3)
+ {
+ c.create(w1, h1, channels1);
+ if (c.empty())
+ return -100;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float *ptr = (const float *)a + h1 * q;
+ const float *ptr1 = b.channel(q);
+ float *outptr = c.channel(q);
+
+ for (int y = 0; y < h1; y++)
+ {
+ const float a0 = ptr[y];
+ for (int x = 0; x < w1; x++)
+ {
+ outptr[x] = op(a0, ptr1[x]);
+ }
+
+ ptr1 += w1;
+ outptr += w1;
+ }
+ }
+
+ return 0;
+ }
+
+ c.create(w, h);
+ if (c.empty())
+ return -100;
+
+ if (b.dims == 2)
+ {
+ for (int i = 0; i < size; i++)
+ {
+ c[i] = op(a[i], b[i]);
+ }
+
+ return 0;
+ }
+
+ if (b.dims == 1)
+ {
+ c.create(w, h);
+ if (c.empty())
+ return -100;
+
+ if (b.w == 1)
+ {
+ const float b0 = b[0];
+ for (int i = 0; i < size; i++)
+ {
+ c[i] = op(a[i], b0);
+ }
+
+ return 0;
+ }
+
+ const float *ptr = a;
+ float *outptr = c;
+
+ for (int y = 0; y < h; y++)
+ {
+ const float b0 = b[y];
+ for (int x = 0; x < w; x++)
+ {
+ outptr[x] = op(ptr[x], b0);
+ }
+
+ ptr += w;
+ outptr += w;
+ }
+
+ return 0;
+ }
+ }
+ else if (a.dims == 1)
+ {
+ if (a.w == 1)
+ {
+ if (b.dims == 3)
+ {
+ c.create(w1, h1, channels1);
+ if (c.empty())
+ return -100;
+
+ const float a0 = a[0];
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float *ptr1 = b.channel(q);
+ float *outptr = c.channel(q);
+
+ for (int i = 0; i < size1; i++)
+ {
+ outptr[i] = op(a0, ptr1[i]);
+ }
+ }
+
+ return 0;
+ }
+
+ if (b.dims == 2)
+ {
+ c.create(w1, h1);
+ if (c.empty())
+ return -100;
+
+ const float a0 = a[0];
+ for (int i = 0; i < size1; i++)
+ {
+ c[i] = op(a0, b[i]);
+ }
+
+ return 0;
+ }
+
+ if (b.dims == 1)
+ {
+ c.create(w1);
+ if (c.empty())
+ return -100;
+
+ const float a0 = a[0];
+ for (int i = 0; i < size1; i++)
+ {
+ c[i] = op(a0, b[i]);
+ }
+
+ return 0;
+ }
+ }
+
+ if (b.dims == 3)
+ {
+ c.create(w1, h1, channels1);
+ if (c.empty())
+ return -100;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float a0 = a[q];
+ const float *ptr1 = b.channel(q);
+ float *outptr = c.channel(q);
+
+ for (int i = 0; i < size1; i++)
+ {
+ outptr[i] = op(a0, ptr1[i]);
+ }
+ }
+
+ return 0;
+ }
+
+ if (b.dims == 2)
+ {
+ c.create(w1, h1);
+ if (c.empty())
+ return -100;
+
+ const float *ptr1 = b;
+ float *outptr = c;
+
+ for (int y = 0; y < h1; y++)
+ {
+ const float a0 = a[y];
+ for (int x = 0; x < w1; x++)
+ {
+ outptr[x] = op(a0, ptr1[x]);
+ }
+
+ ptr1 += w1;
+ outptr += w1;
+ }
+
+ return 0;
+ }
+
+ if (b.dims == 1)
+ {
+ c.create(w);
+ if (c.empty())
+ return -100;
+
+ if (b.w == 1)
+ {
+ const float b0 = b[0];
+ for (int i = 0; i < size; i++)
+ {
+ c[i] = op(a[i], b0);
+ }
+
+ return 0;
+ }
+
+ for (int i = 0; i < size; i++)
+ {
+ c[i] = op(a[i], b[i]);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <typename Op> static int binary_op_scalar_inplace(Mat &a, float b)
+{
+ Op op;
+
+ int w = a.w;
+ int h = a.h;
+ int channels = a.c;
+ int size = w * h;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ float *ptr = a.channel(q);
+
+ for (int i = 0; i < size; i++)
+ {
+ ptr[i] = op(ptr[i], b);
+ }
+ }
+
+ return 0;
+}
+
+template <typename T> struct binary_op_max : std::binary_function<T, T, T>
+{
+ T operator()(const T &x, const T &y) const { return std::max(x, y); }
+};
+
+template <typename T> struct binary_op_min : std::binary_function<T, T, T>
+{
+ T operator()(const T &x, const T &y) const { return std::min(x, y); }
+};
+
+template <typename T> struct binary_op_pow : std::binary_function<T, T, T>
+{
+ T operator()(const T &x, const T &y) const { return pow(x, y); }
+};
+
+template <typename T> struct binary_op_SquaredDifference : std::binary_function<T, T, T>
+{
+ T operator()(const T &x, const T &y) const { return pow((x - y), 2); }
+};
+
+int ncnn_binary_op(const BinaryOpParam ¶m, const Mat &bottom_blob, const Mat &bottom_blob1,
+ Mat &top_blob)
+{
+ int ret = 0;
+ auto op_type = param.op_type;
+ auto b = param.b;
+
+ // Only support add operation, none broadcasting
+ // Other case, need to remove internal memory allocation and check correctness
+ if (op_type != BinaryOp::Operation_ADD)
+ {
+ throw std::runtime_error{"NYI: Only support ADD operation"};
+ }
+ if (bottom_blob.dims != bottom_blob1.dims)
+ {
+ throw std::runtime_error{"NYI: Cannot use broadcasting"};
+ }
+
+// printf("-------------------BinaryOp---------------\n");
+
+// printf("op_type = %d, ", op_type);
+// printf("in1: (%d, %d, %d), dims = %d, ", bottom_blob.w, bottom_blob.h, bottom_blob.c,
+// bottom_blob.dims);
+// printf("in2: (%d, %d, %d), dims = %d\n", bottom_blob1.w, bottom_blob1.h, bottom_blob1.c,
+// bottom_blob1.dims);
+
+#if __ARM_NEON
+ int w = bottom_blob.w;
+ int h = bottom_blob.h;
+ int channels = bottom_blob.c;
+ int size = w * h;
+
+ int w1 = bottom_blob1.w;
+ int h1 = bottom_blob1.h;
+ int channels1 = bottom_blob1.c;
+ int size1 = w1 * h1;
+
+ if (op_type == BinaryOp::Operation_ADD)
+ {
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+ // Fix for nnfw: disable allocation for output
+ // top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1 && bottom_blob1.h == 1)
+ {
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+#if __ARM_NEON
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+ float tt = *ptr1;
+
+ float32x4_t _p2 = vdupq_n_f32(tt);
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+
+ _p1 = vaddq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 + tt);
+ in1++;
+ out++;
+ }
+
+#else
+ float tt = *ptr1;
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = (ptr[i] + tt);
+ }
+#endif
+ }
+
+ ret = 0;
+ }
+ else
+ {
+ if (size * bottom_blob.elemsize % 16 != 0)
+ {
+ throw std::runtime_error{"Unmatched alignment"};
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *in2 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vld1q_f32(in2);
+
+ _p1 = vaddq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ in2 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = *in1 + *in2;
+ in1++;
+ in2++;
+ out++;
+ }
+ }
+ }
+ }
+ else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1)
+ {
+ top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1)
+ {
+ ret = binary_op<std::plus<float>>(bottom_blob, bottom_blob1, top_blob);
+ // return ret;
+ goto out;
+ }
+ float *pt = (float *)bottom_blob1.data;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float b0 = pt[q];
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vdupq_n_f32(b0);
+
+ _p1 = vaddq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 + b0);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w1, h1, channels1);
+ if (top_blob.empty())
+ return -100;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float a0 = bottom_blob[q];
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size1 >> 2;
+ int remain = size1 - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vdupq_n_f32(a0);
+ float32x4_t _p2 = vld1q_f32(in1);
+
+ _p1 = vaddq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (a0 + *in1);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else
+ ret = binary_op<std::plus<float>>(bottom_blob, bottom_blob1, top_blob);
+ }
+
+ if (op_type == BinaryOp::Operation_SUB)
+ {
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w, h, channels);
+
+ if (bottom_blob1.w == 1 && bottom_blob1.h == 1)
+ {
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+#if __ARM_NEON
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+ float tt = *ptr1;
+
+ float32x4_t _p2 = vdupq_n_f32(tt);
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 - tt);
+ in1++;
+ out++;
+ }
+
+#else
+ float tt = *ptr1;
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = (ptr[i] - tt);
+ }
+#endif
+ }
+
+ ret = 0;
+ }
+ else
+ {
+ top_blob.create(w, h, channels);
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *in2 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vld1q_f32(in2);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ in2 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = *in1 - *in2;
+ in1++;
+ in2++;
+ out++;
+ }
+ }
+ }
+ }
+ else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1)
+ {
+ top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1)
+ {
+ ret = binary_op<std::minus<float>>(bottom_blob, bottom_blob1, top_blob);
+ // return ret;
+ goto out;
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float b0 = bottom_blob1[q];
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vdupq_n_f32(b0);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 - b0);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w1, h1, channels1);
+ if (top_blob.empty())
+ return -100;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float a0 = bottom_blob[q];
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size1 >> 2;
+ int remain = size1 - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vdupq_n_f32(a0);
+ float32x4_t _p2 = vld1q_f32(in1);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (a0 - *in1);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else
+ ret = binary_op<std::minus<float>>(bottom_blob, bottom_blob1, top_blob);
+ }
+
+ if (op_type == BinaryOp::Operation_MUL)
+ {
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w, h, channels);
+
+ if (bottom_blob1.w == 1 && bottom_blob1.h == 1)
+ {
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+#if __ARM_NEON
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+ float tt = *ptr1;
+
+ float32x4_t _p2 = vdupq_n_f32(tt);
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+
+ _p1 = vmulq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 * tt);
+ in1++;
+ out++;
+ }
+
+#else
+ float tt = *ptr1;
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = (ptr[i] * tt);
+ }
+#endif
+ }
+
+ ret = 0;
+ }
+ else
+ {
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *in2 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vld1q_f32(in2);
+
+ _p1 = vmulq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ in2 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = *in1 * *in2;
+ in1++;
+ in2++;
+ out++;
+ }
+ }
+ }
+ }
+ else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1)
+ {
+ top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1)
+ {
+ ret = binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob);
+ // return ret;
+ goto out;
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float b0 = bottom_blob1[q];
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vdupq_n_f32(b0);
+
+ _p1 = vmulq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 * b0);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w1, h1, channels1);
+ if (top_blob.empty())
+ return -100;
+
+ if (bottom_blob.w != bottom_blob1.c)
+ {
+ ret = binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob);
+ goto out;
+ }
+
+ float *pt = (float *)bottom_blob.data;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float a0 = pt[q];
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size1 >> 2;
+ int remain = size1 - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vdupq_n_f32(a0);
+ float32x4_t _p2 = vld1q_f32(in1);
+
+ _p1 = vmulq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (a0 * *in1);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else
+ ret = binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob);
+ }
+
+ if (op_type == BinaryOp::Operation_DIV)
+ {
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1 && bottom_blob1.h == 1)
+ {
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+#if __ARM_NEON
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+ float tt = *ptr1;
+
+ float32x4_t _p2 = vdupq_n_f32(tt);
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+
+ float32x4_t _p3 = vrecpeq_f32(_p2);
+ _p3 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3);
+ _p1 = vmulq_f32(_p1, _p3);
+
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 / tt);
+ in1++;
+ out++;
+ }
+
+#else
+ float tt = *ptr1;
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = (ptr[i] / tt);
+ }
+#endif
+ }
+
+ // return 0;
+ goto out;
+ }
+ else
+ {
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *in2 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vld1q_f32(in2);
+
+ float32x4_t _p3 = vrecpeq_f32(_p2);
+ _p2 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3);
+ _p1 = vmulq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ in2 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = *in1 / *in2;
+ in1++;
+ in2++;
+ out++;
+ }
+ }
+ }
+ }
+ else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1)
+ {
+ top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1)
+ {
+ ret = binary_op<std::divides<float>>(bottom_blob, bottom_blob1, top_blob);
+ // return ret;
+ goto out;
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float b0 = bottom_blob1[q];
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vdupq_n_f32(b0);
+
+ //_p1 = vsubq_f32(_p1, _p2);
+ float32x4_t _p3 = vrecpeq_f32(_p2);
+ _p2 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3);
+ _p1 = vmulq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 / b0);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w1, h1, channels1);
+ if (top_blob.empty())
+ return -100;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float a0 = bottom_blob[q];
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size1 >> 2;
+ int remain = size1 - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vdupq_n_f32(a0);
+ float32x4_t _p2 = vld1q_f32(in1);
+
+ //_p1 = vsubq_f32(_p1, _p2);
+ float32x4_t _p3 = vrecpeq_f32(_p2);
+ _p2 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3);
+ _p1 = vmulq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (a0 / *in1);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else
+ ret = binary_op<std::divides<float>>(bottom_blob, bottom_blob1, top_blob);
+ }
+
+ if (op_type == BinaryOp::Operation_MAX)
+ ret = binary_op<binary_op_max<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_MIN)
+ ret = binary_op<binary_op_min<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_POW)
+ {
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w, h, channels);
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *in2 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vld1q_f32(in2);
+
+ _p1 = pow_ps(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ in2 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = pow(*in1, *in2);
+ in1++;
+ in2++;
+ out++;
+ }
+ }
+ }
+ else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1)
+ {
+ top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1)
+ {
+ ret = binary_op<binary_op_pow<float>>(bottom_blob, bottom_blob1, top_blob);
+ // return ret;
+ goto out;
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float b0 = bottom_blob1[q];
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vdupq_n_f32(b0);
+
+ _p1 = pow_ps(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = pow(*in1, b0);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w1, h1, channels1);
+ if (top_blob.empty())
+ return -100;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float a0 = bottom_blob[q];
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size1 >> 2;
+ int remain = size1 - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vdupq_n_f32(a0);
+ float32x4_t _p2 = vld1q_f32(in1);
+
+ _p1 = pow_ps(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = pow(a0, *in1);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else
+ ret = binary_op<binary_op_pow<float>>(bottom_blob, bottom_blob1, top_blob);
+ }
+
+ if (op_type == BinaryOp::Operation_SQUAREDDIFFERENCE)
+ {
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w, h, channels);
+
+ if (bottom_blob1.w == 1 && bottom_blob1.h == 1)
+ {
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+#if __ARM_NEON
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+ float tt = *ptr1;
+
+ float32x4_t _p2 = vdupq_n_f32(tt);
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ _p1 = vmulq_f32(_p1, _p1);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ float t2 = *in1 - tt;
+ *out = t2 * t2;
+ in1++;
+ out++;
+ }
+
+#else
+ float tt = *ptr1;
+ for (int i = 0; i < size; i++)
+ {
+ float t2 = (ptr[i] - tt);
+ outptr[i] = t2 * t2;
+ }
+#endif
+ }
+
+ ret = 0;
+ }
+ else
+ {
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *in2 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vld1q_f32(in2);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ _p1 = vmulq_f32(_p1, _p1);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ in2 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 - *in2) * (*in1 - *in2);
+ in1++;
+ in2++;
+ out++;
+ }
+ }
+ }
+ }
+ else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1)
+ {
+ top_blob.create(w, h, channels);
+ if (bottom_blob1.w == 1)
+ {
+ ret = binary_op<binary_op_SquaredDifference<float>>(bottom_blob, bottom_blob1, top_blob);
+ // return ret;
+ goto out;
+ }
+
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ const float *ptr = bottom_blob.channel(q);
+ const float b0 = bottom_blob1[q];
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vdupq_n_f32(b0);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ _p1 = vmulq_f32(_p1, _p1);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (*in1 - b0) * (*in1 - b0);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3)
+ {
+ top_blob.create(w1, h1, channels1);
+ if (top_blob.empty())
+ return -100;
+
+#pragma omp parallel for
+ for (int q = 0; q < channels1; q++)
+ {
+ const float a0 = bottom_blob[q];
+ const float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size1 >> 2;
+ int remain = size1 - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vdupq_n_f32(a0);
+ float32x4_t _p2 = vld1q_f32(in1);
+
+ _p1 = vsubq_f32(_p1, _p2);
+ _p1 = vmulq_f32(_p1, _p1);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = (a0 - *in1) * (a0 - *in1);
+ in1++;
+ out++;
+ }
+ }
+ }
+ else
+ ret = binary_op<binary_op_SquaredDifference<float>>(bottom_blob, bottom_blob1, top_blob);
+ }
+
+#else
+
+ if (op_type == BinaryOp::Operation_ADD)
+ ret = binary_op<std::plus<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_SUB)
+ ret = binary_op<std::minus<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_MUL)
+ ret = binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_DIV)
+ ret = binary_op<std::divides<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_MAX)
+ ret = binary_op<binary_op_max<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_MIN)
+ ret = binary_op<binary_op_min<float>>(bottom_blob, bottom_blob1, top_blob);
+
+ if (op_type == BinaryOp::Operation_POW)
+ ret = binary_op<binary_op_pow<float>>(bottom_blob, bottom_blob1, top_blob);
+ if (op_type == BinaryOp::Operation_SQUAREDDIFFERENCE)
+ ret = binary_op<binary_op_SquaredDifference<float>>(bottom_blob, bottom_blob1, top_blob);
+#endif
+
+/*
+for (int p = 0; p < top_blob.c && p < 5; p++)
+{
+ float* outptr = top_blob.channel(p);
+ printf("channel: %d\n", p);
+ for (int i = 0; i < 1; i++)
+ {
+ for (int j = 0; j < 5; j++)
+ {
+ printf("%f ", outptr[j]);
+ }
+ printf("\n");
+ outptr += top_blob.w;
+ }
+}
+printf("----------------------------\n");
+*/
+
+out:
+ return ret;
+}
+
+int ncnn_binary_op_inplace(const BinaryOpParam ¶m, Mat &bottom_top_blob)
+{
+ auto op_type = param.op_type;
+ auto b = param.b;
+
+ // printf("-------------------BinaryOp-----forward_inplace----------\n");
+ if (op_type == BinaryOp::Operation_ADD)
+ return binary_op_scalar_inplace<std::plus<float>>(bottom_top_blob, b);
+
+ if (op_type == BinaryOp::Operation_SUB)
+ return binary_op_scalar_inplace<std::minus<float>>(bottom_top_blob, b);
+
+ if (op_type == BinaryOp::Operation_MUL)
+ return binary_op_scalar_inplace<std::multiplies<float>>(bottom_top_blob, b);
+
+ if (op_type == BinaryOp::Operation_DIV)
+ return binary_op_scalar_inplace<std::divides<float>>(bottom_top_blob, b);
+
+ if (op_type == BinaryOp::Operation_MAX)
+ return binary_op_scalar_inplace<binary_op_max<float>>(bottom_top_blob, b);
+
+ if (op_type == BinaryOp::Operation_MIN)
+ return binary_op_scalar_inplace<binary_op_min<float>>(bottom_top_blob, b);
+
+ if (op_type == BinaryOp::Operation_POW)
+ return binary_op_scalar_inplace<binary_op_pow<float>>(bottom_top_blob, b);
+
+ if (op_type == BinaryOp::Operation_SQUAREDDIFFERENCE)
+ return binary_op_scalar_inplace<binary_op_SquaredDifference<float>>(bottom_top_blob, b);
+
+ return 0;
+}
+
+int ncnn_binary_op_inplace(const BinaryOpParam ¶m, Mat &bottom_blob, Mat &bottom_top_blob)
+{
+ int ret = 0;
+
+ Mat &bottom_blob1 = bottom_top_blob;
+ Mat &top_blob = bottom_top_blob;
+ auto op_type = param.op_type;
+
+ if (op_type == BinaryOp::Operation_ADD)
+ {
+ int w = bottom_blob.w;
+ int h = bottom_blob.h;
+ int channels = bottom_blob.c;
+ int size = w * h;
+
+ int w1 = bottom_blob1.w;
+ int h1 = bottom_blob1.h;
+ int channels1 = bottom_blob1.c;
+ int size1 = w1 * h1;
+
+#if __ARM_NEON
+
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ float *ptr = bottom_blob.channel(q);
+ float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ int nn = size >> 2;
+ int remain = size - (nn << 2);
+
+ float *in1 = const_cast<float *>(ptr);
+ float *in2 = const_cast<float *>(ptr1);
+ float *out = const_cast<float *>(outptr);
+
+ for (; nn > 0; nn--)
+ {
+ float32x4_t _p1 = vld1q_f32(in1);
+ float32x4_t _p2 = vld1q_f32(in2);
+
+ _p1 = vaddq_f32(_p1, _p2);
+ vst1q_f32(out, _p1);
+ in1 += 4;
+ in2 += 4;
+ out += 4;
+ }
+ for (; remain > 0; remain--)
+ {
+ *out = *in1 + *in2;
+ in1++;
+ in2++;
+ out++;
+ }
+ }
+ }
+#else
+ if (bottom_blob.dims == 3 && bottom_blob1.dims == 3)
+ {
+#pragma omp parallel for
+ for (int q = 0; q < channels; q++)
+ {
+ float *ptr = bottom_blob.channel(q);
+ float *ptr1 = bottom_blob1.channel(q);
+ float *outptr = top_blob.channel(q);
+
+ for (int i = 0; i < size; i++)
+ {
+ outptr[i] = ptr[i] + ptr1[i];
+ }
+ }
+ return 0;
+ }
+#endif
+ }
+ else
+ {
+ return -1;
+ }
+ return ret;
+}
+
+} // namespace ncnn
+} // namespace ncnn