Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / ReduceLayer.cc
index 1dad031..fe22dbe 100644 (file)
@@ -116,6 +116,39 @@ void evalGeneric(const IPortableTensor *input, IPortableTensor *output,
       throw std::runtime_error{"Reduce(generic): unsupported data type"};
   }
 }
+
+void evalSumQuantized(const IPortableTensor *input, IPortableTensor *output,
+                      const std::vector<int> &axes, bool keep_dims,
+                      nnfw::cker::Reduce &reduce_kernel)
+{
+  const bool same_scale = (input->data_scale() == output->data_scale() &&
+                           input->data_offset() == output->data_offset());
+
+  reduce_kernel.prepare(input->num_dimensions(), axes.size());
+
+  if (!same_scale)
+  {
+    std::vector<int32_t> temp_sum(output->getShape().num_elements());
+    bool result = reduce_kernel.QuantizedMeanOrSum<uint8_t, int32_t>(
+        reinterpret_cast<const uint8_t *>(input->buffer()), input->data_offset(),
+        input->data_scale(), getTensorShape(input), reinterpret_cast<uint8_t *>(output->buffer()),
+        output->data_offset(), output->data_scale(), getTensorShape(output), axes, keep_dims,
+        temp_sum.data(), true, [](const int32_t current, const uint8_t in) -> int32_t {
+          const int32_t actual_in = static_cast<int32_t>(in);
+          return current + actual_in;
+        });
+
+    if (!result)
+    {
+      throw std::runtime_error{"Reduce: Fail to run"};
+    }
+
+    return;
+  }
+
+  evalGeneric<ReduceType::kSum>(input, output, axes, keep_dims, reduce_kernel);
+}
+
 } // namespace
 
 ReduceLayer::ReduceLayer()
@@ -143,6 +176,11 @@ void ReduceLayer::run()
   switch (_reduceType)
   {
     case ReduceType::kSum:
+      if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
+      {
+        evalSumQuantized(_input, _output, axes, _keep_dims, *_reduce_kernel);
+        return;
+      }
       evalGeneric<ReduceType::kSum>(_input, _output, axes, _keep_dims, *_reduce_kernel);
       break;
     case ReduceType::kProd: