#include "kernels/BinaryOpCommon.h"
#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+#include <tensorflow/lite/kernels/internal/reference/binary_function.h>
+#include <tensorflow/lite/kernels/internal/reference/prelu.h>
#include <stdexcept>
constexpr int32_t quantized_max = std::numeric_limits<int16_t>::max();
const int32_t output_val =
- input_val >= 0 ? tflite::MultiplyByQuantizedMultiplier(input_val, identity_mult.multiplier,
- identity_mult.shift)
- : tflite::MultiplyByQuantizedMultiplier(input_val * alpha_val,
- alpha_mult.multiplier, alpha_mult.shift);
+ input_val >= 0
+ ? tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val),
+ identity_mult.multiplier, identity_mult.shift)
+ : tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val * alpha_val),
+ alpha_mult.multiplier, alpha_mult.shift);
const int32_t clamped_output = std::min(quantized_max, std::max(quantized_min, output_val));
return clamped_output;
}