From 7dc7ad2d69fc6935dfff7fcf7c67df89562dd8da Mon Sep 17 00:00:00 2001
From: Inki Dae
Date: Fri, 3 Jan 2025 14:12:24 +0900
Subject: [PATCH] update quantization parameters correctly

Update quantization parameters correctly for the output tensor when the
output layer of the given model has quantization parameters.

Change-Id: Id278e55882f7b6135dcf55f83c050f2df4f9d98c
Signed-off-by: Inki Dae
---
 src/inference_engine_tflite.cpp | 36 ++++++++++++++++++++++++++----------
 1 file changed, 26 insertions(+), 10 deletions(-)

diff --git a/src/inference_engine_tflite.cpp b/src/inference_engine_tflite.cpp
index 74dc0c3..caa1bd6 100644
--- a/src/inference_engine_tflite.cpp
+++ b/src/inference_engine_tflite.cpp
@@ -303,31 +303,34 @@ namespace TFLiteImpl
 			return INFERENCE_ENGINE_ERROR_INVALID_OPERATION;
 		}
 
+		const TfLiteTensor *tensor = mInterpreter->tensor(layer.second);
+		if (!tensor) {
+			LOGE("tensor for tensor index(%d) is null", layer.second);
+			return INFERENCE_ENGINE_ERROR_INVALID_OPERATION;
+		}
+
 		inference_engine_tensor_info tensor_info;
 
 		LOGI("mInterpreter->tensor(%d)->dims name[%s] size[%d] type[%d]",
-			 layer.second,
-			 mInterpreter->tensor(layer.second)->name,
-			 mInterpreter->tensor(layer.second)->dims->size,
-			 mInterpreter->tensor(layer.second)->type);
+			 layer.second, tensor->name, tensor->dims->size, tensor->type);
 
 		std::vector<size_t> shape_nhwc;
-		for (int idx = 0; idx < mInterpreter->tensor(layer.second)->dims->size; idx++)
-			shape_nhwc.push_back(mInterpreter->tensor(layer.second)->dims->data[idx]);
+		for (int idx = 0; idx < tensor->dims->size; idx++)
+			shape_nhwc.push_back(tensor->dims->data[idx]);
 
 		//tflite only supports NHWC (https://www.tensorflow.org/lite/guide/ops_compatibility).
 		tensor_info.shape = shape_nhwc;
 		tensor_info.shape_type = INFERENCE_TENSOR_SHAPE_NHWC;
-		if (mInterpreter->tensor(layer.second)->type == kTfLiteUInt8) {
+		if (tensor->type == kTfLiteUInt8) {
 			LOGI("type is kTfLiteUInt8");
 			tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_UINT8;
-		} else if (mInterpreter->tensor(layer.second)->type == kTfLiteInt8) {
+		} else if (tensor->type == kTfLiteInt8) {
 			LOGI("type is kTfLiteInt8");
 			tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT8;
-		} else if (mInterpreter->tensor(layer.second)->type == kTfLiteInt64) {
+		} else if (tensor->type == kTfLiteInt64) {
 			LOGI("type is kTfLiteInt64");
 			tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_INT64;
-		} else if (mInterpreter->tensor(layer.second)->type == kTfLiteFloat32) {
+		} else if (tensor->type == kTfLiteFloat32) {
 			LOGI("type is kTfLiteFloat32");
 			tensor_info.data_type = INFERENCE_TENSOR_DATA_TYPE_FLOAT32;
 		} else {
@@ -339,6 +342,19 @@ namespace TFLiteImpl
 		for (auto & dim : tensor_info.shape)
 			tensor_info.size *= dim;
 
+		if (tensor->quantization.type == kTfLiteAffineQuantization) {
+			auto *quant_params = reinterpret_cast<TfLiteAffineQuantization *>(tensor->quantization.params);
+
+			LOGD("This layer has quantization parameters.");
+			if (quant_params) {
+				tensor_info.scale = quant_params->scale->data[0];
+				tensor_info.zero_point = quant_params->zero_point->data[0];
+				tensor_info.quantization_type = INFERENCE_TENSOR_QUANTIZATION_AFFINE;
+
+				LOGD("Quantization params : type(%d), scale(%f), zero point(%d)", tensor_info.quantization_type, tensor_info.scale, tensor_info.zero_point);
+			}
+		}
+
 		mOutputLayers.insert(std::make_pair(mInterpreter->tensor(layer.second)->name, tensor_info));
 	}
 
-- 
2.34.1
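
Note (review sketch, not part of the patch): the scale and zero_point
recorded above implement TFLite's affine quantization scheme,
real = scale * (quantized - zero_point). Below is a minimal C++ sketch
of how a consumer of mOutputLayers might dequantize a kTfLiteUInt8
output buffer using those two fields; the struct and function names are
illustrative stand-ins, only scale and zero_point correspond to what
the patch stores.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for the two inference_engine_tensor_info
    // fields this patch fills in (tensor_info.scale, tensor_info.zero_point).
    struct quant_info {
    	float scale;
    	int zero_point;
    };

    // Affine dequantization as defined by TFLite:
    // real = scale * (quantized - zero_point).
    static std::vector<float> dequantize_u8(const uint8_t *data, size_t count,
    					const quant_info &q)
    {
    	std::vector<float> out(count);

    	for (size_t i = 0; i < count; ++i)
    		out[i] = q.scale * (static_cast<int>(data[i]) - q.zero_point);

    	return out;
    }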