default:
LOG(FATAL) << "Unhandled final quantization type "
<< static_cast<int>(array.final_data_type);
- return default_type;
+ }
+}
+
+void GetQuantizationParams(ArrayDataType data_type,
+ const ModelFlags& model_flags, const MinMax& minmax,
+ QuantizationParams* quantization_params) {
+ switch (data_type) {
+ case ArrayDataType::kInt8:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt8>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kUint8:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kInt16:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kUint16:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint16>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kInt32:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt32>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kUint32:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint32>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kInt64:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt64>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kUint64:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint64>(
+ model_flags, minmax, quantization_params);
+ break;
+ case ArrayDataType::kFloat:
+ case ArrayDataType::kNone:
+ default:
+ LOG(FATAL) << "Unhandled final quantization type "
+ << static_cast<int>(data_type);
}
}
if (op.type == OperatorType::kLstmCell) {
if (input_index == LstmCellOperator::PREV_STATE_INPUT) {
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- model->flags, minmax, quantization_params);
*quantized_data_type = ArrayDataType::kInt16;
+ GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+ quantization_params);
return true;
}
}
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
- quantization_params);
*quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
+ GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+ quantization_params);
transformation->AddMessageF(
"For input array %s with min=%g"
", max=%g"
if (op.type == OperatorType::kLstmCell) {
if (output_index == LstmCellOperator::STATE_OUTPUT ||
output_index == LstmCellOperator::ACTIV_TEMP) {
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- model->flags, minmax, quantization_params);
*quantized_data_type = ArrayDataType::kInt16;
+ GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+ quantization_params);
return true;
}
}
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
- quantization_params);
*quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
+ GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+ quantization_params);
transformation->AddMessageF(
"For output array %s with min=%g, max=%g"
", chose to quantize as %s with zero_point=%d"