const auto in1_shape = shapeToRuntimeShape(in1.getShape());
const auto in2_shape = shapeToRuntimeShape(in2.getShape());
assert(in2_shape.DimensionsCount() == 1 &&
- in2_shape.Dims(0) == in1_shape.Dims(in1_shape.DimensionsCount() - 1));
+ in2_shape.Dims(0) == in1_shape.Dims(in1_shape.DimensionsCount() - 1));
out.reshape(in1.getShape());
+
+#ifdef USE_NEON
+ const int scale_size = in2_shape.FlatSize();
+ const int array_size = in1_shape.FlatSize();
+ TFLITE_DCHECK_EQ((array_size % scale_size), 0);
+ out.fillData(in1.getData(), array_size);
+ float* array_ptr = out.getData();
+ const float* scale_ptr = in2.getData();
+ float* array_end_ptr = array_ptr + array_size;
+ for (; array_ptr != array_end_ptr; array_ptr += scale_size) {
+ int i = 0;
+ for (; i <= scale_size - 16; i += 16) {
+ auto b0 = vld1q_f32(scale_ptr + i);
+ auto b1 = vld1q_f32(scale_ptr + i + 4);
+ auto b2 = vld1q_f32(scale_ptr + i + 8);
+ auto b3 = vld1q_f32(scale_ptr + i + 12);
+ auto a0 = vld1q_f32(array_ptr + i);
+ auto a1 = vld1q_f32(array_ptr + i + 4);
+ auto a2 = vld1q_f32(array_ptr + i + 8);
+ auto a3 = vld1q_f32(array_ptr + i + 12);
+ auto x0 = vmulq_f32(a0, b0);
+ auto x1 = vmulq_f32(a1, b1);
+ auto x2 = vmulq_f32(a2, b2);
+ auto x3 = vmulq_f32(a3, b3);
+ vst1q_f32(array_ptr + i, x0);
+ vst1q_f32(array_ptr + i + 4, x1);
+ vst1q_f32(array_ptr + i + 8, x2);
+ vst1q_f32(array_ptr + i + 12, x3);
+ }
+ for (; i <= scale_size - 4; i += 4) {
+ auto b = vld1q_f32(scale_ptr + i);
+ auto a = vld1q_f32(array_ptr + i);
+ auto x = vmulq_f32(a, b);
+ vst1q_f32(array_ptr + i, x);
+ }
+ for (; i < scale_size; i++) {
+ array_ptr[i] = array_ptr[i] * scale_ptr[i];
+ }
+ }
+#else // not NEON
+
const auto out_shape = shapeToRuntimeShape(out.getShape());
const auto in1_mat = MapAsMatrixWithLastDimAsRows(in1.getData(), in1_shape);
auto out_mat = MapAsMatrixWithLastDimAsRows(out.getData(), out_shape);
out_mat.colwise() = in2_vec;
out_mat.array() = out_mat.array() * in1_mat.array();
+#endif
}