From: Inki Dae
Date: Fri, 18 Oct 2024 08:28:19 +0000 (+0900)
Subject: add INT8 data type support
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=refs%2Fheads%2Ftizen_devel;p=platform%2Fcore%2Fmultimedia%2Finference-engine-tflite.git

add INT8 data type support

Change-Id: If85622bc22988dcf0226df8f08d9c21543e46a3e
Signed-off-by: Inki Dae
---

diff --git a/src/inference_engine_tflite.cpp b/src/inference_engine_tflite.cpp
index ee52348..3697e71 100644
--- a/src/inference_engine_tflite.cpp
+++ b/src/inference_engine_tflite.cpp
@@ -176,6 +176,12 @@ namespace TFLiteImpl
 		case INFERENCE_TENSOR_DATA_TYPE_UINT8:
 			pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mInputLayerId[layer.first]));
 			buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_UINT8, size, 1 };
+			LOGD("buffer type is UINT8");
+			break;
+		case INFERENCE_TENSOR_DATA_TYPE_INT8:
+			pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mInputLayerId[layer.first]));
+			buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_INT8, size, 1 };
+			LOGD("buffer type is INT8");
 			break;
 		case INFERENCE_TENSOR_DATA_TYPE_FLOAT32:
 			pBuff = static_cast<void *>(mInterpreter->typed_tensor<float>(mInputLayerId[layer.first]));
@@ -185,6 +191,7 @@ namespace TFLiteImpl
 			LOGE("Not supported");
 			return INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
 		}
+
 		buffers.insert(std::make_pair(layer.first, buffer));
 	}
@@ -216,6 +223,11 @@ namespace TFLiteImpl
 			pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mOutputLayerId[layer.first]));
 			buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_UINT8, size, 1 };
 			break;
+		case kTfLiteInt8:
+			LOGI("type is kTfLiteInt8");
+			pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mOutputLayerId[layer.first]));
+			buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_INT8, size, 1 };
+			break;
 		case kTfLiteInt64:
 			LOGI("type is kTfLiteInt64");
 			pBuff = static_cast<void *>(mInterpreter->typed_tensor<int64_t>(mOutputLayerId[layer.first]));
@@ -281,6 +293,9 @@ namespace TFLiteImpl
 		if (mInterpreter->tensor(layer.second)->type == kTfLiteUInt8) {
 			LOGI("type is kTfLiteUInt8");
 			tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_UINT8;
+		} else if (mInterpreter->tensor(layer.second)->type == kTfLiteInt8) {
+			LOGI("type is kTfLiteInt8");
+			tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT8;
 		} else if (mInterpreter->tensor(layer.second)->type == kTfLiteInt64) {
 			LOGI("type is kTfLiteInt64");
 			tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT64;
@@ -362,11 +377,16 @@ namespace TFLiteImpl
 		if (mIsDynamicTensorMode)
 			for (auto &input_buffer : input_buffers) {
 				void *pBuff;
+
 				switch (mInterpreter->tensor(mInputLayerId[input_buffer.first])->type) {
 				case kTfLiteUInt8:
 					LOGI("type is kTfLiteUInt8");
 					pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mInputLayerId[input_buffer.first]));
 					break;
+				case kTfLiteInt8:
+					LOGI("type is kTfLiteInt8");
+					pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mInputLayerId[input_buffer.first]));
+					break;
 				case kTfLiteInt64:
 					LOGI("type is kTfLiteInt64");
 					pBuff = static_cast<void *>(mInterpreter->typed_tensor<int64_t>(mInputLayerId[input_buffer.first]));
@@ -379,6 +399,7 @@ namespace TFLiteImpl
 					LOGE("Not supported");
 					return INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
 				}
+
 				memcpy(pBuff, input_buffer.second.buffer, input_buffer.second.size);
 			}
@@ -397,6 +418,10 @@ namespace TFLiteImpl
 					LOGI("type is kTfLiteUInt8");
 					pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mOutputLayerId[output_buffer.first]));
 					break;
+				case kTfLiteInt8:
+					LOGI("type is kTfLiteInt8");
+					pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mOutputLayerId[output_buffer.first]));
+					break;
 				case kTfLiteInt64:
 					LOGI("type is kTfLiteInt64");
 					pBuff = static_cast<void *>(mInterpreter->typed_tensor<int64_t>(mOutputLayerId[output_buffer.first]));
@@ -487,6 +512,10 @@ namespace TFLiteImpl
 					LOGI("type is kTfLiteUInt8");
 					tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_UINT8;
 					break;
+				case kTfLiteInt8:
+					LOGI("type is kTfLiteInt8");
+					tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT8;
+					break;
 				case kTfLiteFloat32:
 					LOGI("type is kTfLiteFloat32");
 					tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_FLOAT32;
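
Note (not part of the patch): every kTfLiteInt8 case added above resolves the raw tensor pointer through TensorFlow Lite's typed_tensor<int8_t>() accessor, mirroring the existing uint8/float/int64 branches, and hands the caller an untouched quantized buffer. The sketch below shows the same int8 access pattern against the stock TFLite C++ API, plus the usual dequantization step. It is a minimal illustration, not code from this repository; "model.tflite" is a placeholder for any int8-quantized model with a single input and output.

// Minimal sketch: driving an int8-quantized model through the TFLite
// C++ API. "model.tflite" is a hypothetical path.
#include <cstdint>
#include <cstdio>
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main()
{
	auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
	if (!model)
		return 1;

	tflite::ops::builtin::BuiltinOpResolver resolver;
	std::unique_ptr<tflite::Interpreter> interpreter;
	tflite::InterpreterBuilder(*model, resolver)(&interpreter);
	if (!interpreter || interpreter->AllocateTensors() != kTfLiteOk)
		return 1;

	// Same accessor family the patch uses: typed_input_tensor<int8_t>()
	// is a convenience wrapper around typed_tensor<int8_t>() for
	// input index 0.
	int8_t *input = interpreter->typed_input_tensor<int8_t>(0);
	input[0] = 0; /* a real caller fills the whole input buffer */

	if (interpreter->Invoke() != kTfLiteOk)
		return 1;

	// int8 tensors are affine-quantized; recover real values with
	// real = scale * (q - zero_point).
	const TfLiteTensor *out = interpreter->output_tensor(0);
	const int8_t *q = interpreter->typed_output_tensor<int8_t>(0);
	printf("first output: %f\n",
	       out->params.scale * (q[0] - out->params.zero_point));
	return 0;
}

The quantization parameters (params.scale, params.zero_point) come from the model file itself; the engine code in this patch only exposes the raw INT8 buffer and leaves any dequantization to the caller.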