Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / PRelu.cpp
index a53ac6f..5a6b05c 100644 (file)
@@ -19,7 +19,8 @@
 #include "kernels/BinaryOpCommon.h"
 #include "kernels/Utils.h"
 
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+#include <tensorflow/lite/kernels/internal/reference/binary_function.h>
+#include <tensorflow/lite/kernels/internal/reference/prelu.h>
 
 #include <stdexcept>
 
@@ -168,10 +169,11 @@ static inline int16_t evalElemS16PRelu(int16_t input_val, int16_t alpha_val,
   constexpr int32_t quantized_max = std::numeric_limits<int16_t>::max();
 
   const int32_t output_val =
-    input_val >= 0 ? tflite::MultiplyByQuantizedMultiplier(input_val, identity_mult.multiplier,
-                                                           identity_mult.shift)
-                   : tflite::MultiplyByQuantizedMultiplier(input_val * alpha_val,
-                                                           alpha_mult.multiplier, alpha_mult.shift);
+    input_val >= 0
+      ? tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val),
+                                              identity_mult.multiplier, identity_mult.shift)
+      : tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val * alpha_val),
+                                              alpha_mult.multiplier, alpha_mult.shift);
   const int32_t clamped_output = std::min(quantized_max, std::max(quantized_min, output_val));
   return clamped_output;
 }