Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / FullyConnected.cpp
index 48433b4..cfe8f8b 100644 (file)
@@ -19,6 +19,7 @@
 #include "kernels/Utils.h"
 
 #include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
 
 #include <stdexcept>
 
@@ -48,6 +49,12 @@ void FullyConnected::configure()
     LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
     LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
   }
+  else if (weights()->element_type() == DataType::S8)
+  {
+    LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::S8);
+    LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8);
+    LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
+  }
   else
   {
     throw std::runtime_error("Unsupported type.");
@@ -77,6 +84,9 @@ void FullyConnected::execute() const
     case DataType::U8:
       evalQuantized();
       break;
+    case DataType::S8:
+      evalQuantizedS8();
+      break;
     case DataType::FLOAT32:
       evalFloat();
       break;
@@ -135,5 +145,38 @@ void FullyConnected::evalQuantized() const
     getTensorShape(output()), getTensorData<uint8_t>(output()));
 }
 
+void FullyConnected::evalQuantizedS8() const
+{
+  double real_multiplier = 0.0;
+  int output_shift;
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+  int32_t output_multiplier;
+  real_multiplier =
+    getQuantizedConvolutionMultipler(input()->scale(), weights()->scale(), output()->scale());
+  quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
+  calculateActivationRangeQuantized(params().activation, output(), &output_activation_min,
+                                    &output_activation_max);
+
+  int32_t input_offset = -input()->zero_point();
+  int32_t filter_offset = -weights()->zero_point();
+  int32_t output_offset = output()->zero_point();
+
+  tflite::FullyConnectedParams op_params{};
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  op_params.output_shift = output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+  op_params.lhs_cacheable = false;
+  op_params.rhs_cacheable = false;
+  tflite::reference_integer_ops::FullyConnected(
+    op_params, getTensorShape(input()), getTensorData<int8_t>(input()), getTensorShape(weights()),
+    getTensorData<int8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
+    getTensorShape(output()), getTensorData<int8_t>(output()));
+}
+
 } // namespace kernels
 } // namespace luci_interpreter