add INT8 data type support 41/319341/2 tizen_devel
authorInki Dae <inki.dae@samsung.com>
Fri, 18 Oct 2024 08:28:19 +0000 (17:28 +0900)
committerInki Dae <inki.dae@samsung.com>
Mon, 21 Oct 2024 22:57:45 +0000 (07:57 +0900)
Change-Id: If85622bc22988dcf0226df8f08d9c21543e46a3e
Signed-off-by: Inki Dae <inki.dae@samsung.com>
src/inference_engine_tflite.cpp

index ee52348234c262d4178717e840c8ad88759e6855..3697e71d1110725a23ae6840e5352e3dc14a911f 100644 (file)
@@ -176,6 +176,12 @@ namespace TFLiteImpl
                        case INFERENCE_TENSOR_DATA_TYPE_UINT8:
                                pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mInputLayerId[layer.first]));
                                buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_UINT8, size, 1 };
+                               LOGD("buffer type is UINT8");
+                               break;
+                       case INFERENCE_TENSOR_DATA_TYPE_INT8:
+                               pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mInputLayerId[layer.first]));
+                               buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_INT8, size, 1 };
+                               LOGD("buffer type is INT8");
                                break;
                        case INFERENCE_TENSOR_DATA_TYPE_FLOAT32:
                                pBuff = static_cast<void *>(mInterpreter->typed_tensor<float>(mInputLayerId[layer.first]));
@@ -185,6 +191,7 @@ namespace TFLiteImpl
                                LOGE("Not supported");
                                return INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
                        }
+
                        buffers.insert(std::make_pair(layer.first, buffer));
                }
 
@@ -216,6 +223,11 @@ namespace TFLiteImpl
                                pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mOutputLayerId[layer.first]));
                                buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_UINT8, size, 1 };
                                break;
+                       case kTfLiteInt8:
+                               LOGI("type is kTfLiteInt8");
+                               pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mOutputLayerId[layer.first]));
+                               buffer = { pBuff, INFERENCE_TENSOR_DATA_TYPE_INT8, size, 1 };
+                               break;
                        case kTfLiteInt64:
                                LOGI("type is kTfLiteInt64");
                                pBuff = static_cast<void *>(mInterpreter->typed_tensor<int64_t>(mOutputLayerId[layer.first]));
@@ -281,6 +293,9 @@ namespace TFLiteImpl
                        if (mInterpreter->tensor(layer.second)->type == kTfLiteUInt8) {
                                LOGI("type is kTfLiteUInt8");
                                tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_UINT8;
+                       } else if (mInterpreter->tensor(layer.second)->type == kTfLiteInt8) {
+                               LOGI("type is kTfLiteInt8");
+                               tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT8;
                        } else if (mInterpreter->tensor(layer.second)->type == kTfLiteInt64) {
                                LOGI("type is kTfLiteInt64");
                                tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT64;
@@ -362,11 +377,16 @@ namespace TFLiteImpl
                if (mIsDynamicTensorMode)
                        for (auto &input_buffer : input_buffers) {
                                void *pBuff;
+
                                switch (mInterpreter->tensor(mInputLayerId[input_buffer.first])->type) {
                                case kTfLiteUInt8:
                                        LOGI("type is kTfLiteUInt8");
                                        pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mInputLayerId[input_buffer.first]));
                                        break;
+                               case kTfLiteInt8:
+                                       LOGI("type is kTfLiteInt8");
+                                       pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mInputLayerId[input_buffer.first]));
+                                       break;
                                case kTfLiteInt64:
                                        LOGI("type is kTfLiteInt64");
                                        pBuff = static_cast<void *>(mInterpreter->typed_tensor<int64_t>(mInputLayerId[input_buffer.first]));
@@ -379,6 +399,7 @@ namespace TFLiteImpl
                                        LOGE("Not supported");
                                        return INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
                                }
+
                                memcpy(pBuff, input_buffer.second.buffer, input_buffer.second.size);
                        }
 
@@ -397,6 +418,10 @@ namespace TFLiteImpl
                                        LOGI("type is kTfLiteUInt8");
                                        pBuff = static_cast<void *>(mInterpreter->typed_tensor<uint8_t>(mOutputLayerId[output_buffer.first]));
                                        break;
+                               case kTfLiteInt8:
+                                       LOGI("type is kTfLiteInt8");
+                                       pBuff = static_cast<void *>(mInterpreter->typed_tensor<int8_t>(mOutputLayerId[output_buffer.first]));
+                                       break;
                                case kTfLiteInt64:
                                        LOGI("type is kTfLiteInt64");
                                        pBuff = static_cast<void *>(mInterpreter->typed_tensor<int64_t>(mOutputLayerId[output_buffer.first]));
@@ -487,6 +512,10 @@ namespace TFLiteImpl
                                LOGI("type is kTfLiteUInt8");
                                tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_UINT8;
                                break;
+                       case kTfLiteInt8:
+                               LOGI("type is kTfLiteInt8");
+                               tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT8;
+                               break;
                        case kTfLiteFloat32:
                                LOGI("type is kTfLiteFloat32");
                                tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_FLOAT32;