Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / LogSoftmax.cpp
index 03d13e4..79c3153 100644 (file)
@@ -18,9 +18,9 @@
 
 #include "kernels/Utils.h"
 
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+#include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
 
-#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
+#include "PALLogSoftmax.h"
 
 namespace luci_interpreter
 {
@@ -41,8 +41,7 @@ void LogSoftmax::configure()
 
     params.table = _table;
     params.beta = 1.0;
-
-    tflite::optimized_ops::PopulateSoftmaxLookupTable(&params, input()->scale(), params.beta);
+    luci_interpreter_pal::PopulateSoftmaxLookupTable(&params, input()->scale(), params.beta);
   }
   output()->resize(input()->shape());
 }
@@ -76,6 +75,7 @@ void LogSoftmax::evalQuantized() const
   const auto input_scale = input()->scale();
   uint8_t *output_data = getTensorData<uint8_t>(output());
   const uint8_t *input_data = getTensorData<uint8_t>(input());
+  const float beta = 1.0;
 
   tflite::SoftmaxParams params{};
 
@@ -83,8 +83,9 @@ void LogSoftmax::evalQuantized() const
   params.zero_point = output()->zero_point();
   params.scale = output()->scale();
 
-  tflite::optimized_ops::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
-                                    output_data);
+  luci_interpreter_pal::InitializeParams(&params, input_scale, beta);
+  luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
+                                   output_data);
 }
 
 } // namespace kernels