#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+#include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
-#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
+#include "PALLogSoftmax.h"
namespace luci_interpreter
{
params.table = _table;
params.beta = 1.0;
-
- tflite::optimized_ops::PopulateSoftmaxLookupTable(¶ms, input()->scale(), params.beta);
+ luci_interpreter_pal::PopulateSoftmaxLookupTable(¶ms, input()->scale(), params.beta);
}
output()->resize(input()->shape());
}
const auto input_scale = input()->scale();
uint8_t *output_data = getTensorData<uint8_t>(output());
const uint8_t *input_data = getTensorData<uint8_t>(input());
+ const float beta = 1.0;
tflite::SoftmaxParams params{};
params.zero_point = output()->zero_point();
params.scale = output()->scale();
- tflite::optimized_ops::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
- output_data);
+ luci_interpreter_pal::InitializeParams(¶ms, input_scale, beta);
+ luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
+ output_data);
}
} // namespace kernels