void QuantizeArray(GraphTransformation* transformation, Model* model,
const string& name, ArrayDataType quantized_data_type,
const QuantizationParams& quantization_params) {
- switch (quantized_data_type) {
+ ArrayDataType adjusted_data_type = quantized_data_type;
+ auto& array = model->GetArray(name);
+ if (array.final_data_type == ArrayDataType::kInt16) {
+ adjusted_data_type = array.final_data_type;
+ }
+
+ switch (adjusted_data_type) {
case ArrayDataType::kUint8:
return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
quantization_params);
"proceed with quantization.";
}
+// The representable extrema and midpoint of a quantized integer type,
+// widened to int64 so any storage type up to 64 bits fits. Filled in by
+// GetQuantizationPoints() and used to derive zero_point/scale values.
+struct QuantizationPoints {
+ int64 min_value;     // std::numeric_limits<Integer>::min()
+ int64 max_value;     // std::numeric_limits<Integer>::max()
+ int64 central_value; // midpoint of [min, max]: 0 for int8/int16, 128 for uint8
+};
+
+// Returns the QuantizationPoints of the integer type backing ArrayDataType A.
+template <ArrayDataType A>
+QuantizationPoints GetQuantizationPoints() {
+ QuantizationPoints qp;
+ using Integer = DataType<A>;
+ qp.min_value = std::numeric_limits<Integer>::min();
+ qp.max_value = std::numeric_limits<Integer>::max();
+ // Midpoint of [min_value, max_value], rounded up so that e.g. for int8's
+ // range [-128, 127] the result is 0, and for uint8's [0, 255] it is 128.
+ qp.central_value = (qp.min_value / 2 + // -128 -> -64.
+ (qp.max_value - 1) / 2 + // 127 -> 63.
+ 1); // -64 + 63 + 1 == 0.
+ return qp;
+}
+
+// Runtime-dispatch overload of the templated GetQuantizationPoints<A>() for
+// the integer types supported as quantized storage; any other type is a
+// fatal error.
+QuantizationPoints GetQuantizationPoints(ArrayDataType data_type) {
+ switch (data_type) {
+ case ArrayDataType::kUint8:
+ return GetQuantizationPoints<ArrayDataType::kUint8>();
+ case ArrayDataType::kInt16:
+ return GetQuantizationPoints<ArrayDataType::kInt16>();
+ case ArrayDataType::kInt32:
+ return GetQuantizationPoints<ArrayDataType::kInt32>();
+ default:
+ // NOTE(review): no return after this — presumably LOG(FATAL) aborts;
+ // confirm against the project's logging macros.
+ LOG(FATAL) << "Unhandled case.";
+ }
+}
+
+// Chooses the storage type for quantizing 'array': an explicitly requested
+// integer final_data_type takes precedence; kFloat or kNone (no request)
+// falls back to 'default_type'. Any other final_data_type is a fatal error.
+ArrayDataType GetQuantizedDataType(const Array& array,
+ ArrayDataType default_type) {
+ switch (array.final_data_type) {
+ case ArrayDataType::kInt8:
+ case ArrayDataType::kUint8:
+ case ArrayDataType::kInt16:
+ case ArrayDataType::kUint16:
+ case ArrayDataType::kInt32:
+ case ArrayDataType::kUint32:
+ case ArrayDataType::kInt64:
+ case ArrayDataType::kUint64:
+ // An integer type was explicitly requested for this array; honor it.
+ return array.final_data_type;
+ case ArrayDataType::kFloat:
+ case ArrayDataType::kNone:
+ return default_type;
+ default:
+ LOG(FATAL) << "Unhandled final quantization type "
+ << static_cast<int>(array.final_data_type);
+ // Unreachable if LOG(FATAL) aborts; keeps the compiler satisfied.
+ return default_type;
+ }
+}
+
bool ChooseQuantizationForOperatorInput(
GraphTransformation* transformation, Model* model, const Operator& op,
std::size_t input_index, ArrayDataType* quantized_data_type,
const auto input_weights_scale = input_weights.quantization_params->scale;
quantization_params->scale = input_activations_scale * input_weights_scale;
quantization_params->zero_point = 0;
- *quantized_data_type = ArrayDataType::kInt32;
+ *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kInt32);
transformation->AddMessageF(
"Input array %s is a bias vector. Choosing quantization params "
"accordingly.",
GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
quantization_params);
+ *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
transformation->AddMessageF(
"For input array %s with min=%g"
", max=%g"
- ", chose to quantize as uint8 with zero_point=%d"
+ ", chose to quantize as %s with zero_point=%d"
", scale=%g",
- input, minmax.min, minmax.max, quantization_params->zero_point,
- quantization_params->scale);
- *quantized_data_type = ArrayDataType::kUint8;
+ input, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
+ quantization_params->zero_point, quantization_params->scale);
return true;
}
return true;
}
+// Quantized data type is preset to the type of the input before this function.
bool ChooseHardcodedQuantizationForOperatorOutput(
- const Operator& op, ArrayDataType* quantized_data_type,
+ const Operator& op, const Array& array, ArrayDataType* quantized_data_type,
QuantizationParams* quantization_params) {
if (op.type == OperatorType::kL2Normalization) {
// L2Normalization has range: [-1, 1].
// 0 should be exactly representable, as values will typically be centered
// around 0, with many values near 0.
- *quantized_data_type = ArrayDataType::kUint8;
- quantization_params->zero_point = 128;
- quantization_params->scale = 1. / 128.;
+ *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
+ const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
+ quantization_params->zero_point = qp.central_value;
+ quantization_params->scale = 1. / (qp.central_value - qp.min_value);
CHECK(
IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
return true;
// will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and
// the glueing of the two halves of the graph will only be seamless if we
// are accurately representing logistic(0) == 0.5.
- *quantized_data_type = ArrayDataType::kUint8;
+ *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
+ const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
quantization_params->zero_point = 0;
- quantization_params->scale = 1. / 256.;
+ quantization_params->scale = 1. / (qp.max_value + 1);
CHECK(IsExactlyRepresentable(0.5, *quantized_data_type,
*quantization_params));
return true;
}
if (op.type == OperatorType::kTanh) {
// Tanh has the range: [-1, 1].
- *quantized_data_type = ArrayDataType::kUint8;
- quantization_params->zero_point = 128;
- quantization_params->scale = 1. / 128.;
+ *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
+ const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
+ quantization_params->zero_point = qp.central_value;
+ quantization_params->scale = 1. / (qp.central_value - qp.min_value);
// 0 should be exactly representable, as values will typically be centered
// around 0, with many values near 0.
CHECK(
if (array.data_type != ArrayDataType::kFloat) {
return false;
}
- if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type,
- quantization_params)) {
+ *quantized_data_type = model->GetArray(op.inputs[0]).data_type;
+ if (ChooseHardcodedQuantizationForOperatorOutput(
+ op, array, quantized_data_type, quantization_params)) {
transformation->AddMessageF(
"Output array %s is produced by a %s operator. Choosing fixed "
"quantization params accordingly.",
return true;
}
if ((op.type == OperatorType::kDepthToSpace) ||
- (op.type == OperatorType::kSpaceToDepth)) {
- // DepthToSpace and SpaceToDepth should preserve the quantization parameters
- // of the input array, as these are simple reshape operations.
- const auto& input_quantization_params =
- model->GetArray(op.inputs[0]).GetQuantizationParams();
- *quantized_data_type = ArrayDataType::kUint8;
+ (op.type == OperatorType::kSpaceToDepth) ||
+ (op.type == OperatorType::kTensorFlowReshape) ||
+ (op.type == OperatorType::kTensorFlowSplit) ||
+ (op.type == OperatorType::kConcatenation)) {
+ int data_input_index = 0;
+ if (op.type == OperatorType::kTensorFlowSplit) {
+ data_input_index = 1;
+ }
+ // Copying and rearrangement ops should preserve the quantization parameters
+ // of the input array.
+ const auto& input_array = model->GetArray(op.inputs[data_input_index]);
+ const auto& input_quantization_params = input_array.GetQuantizationParams();
+ *quantized_data_type =
+ GetQuantizedDataType(input_array, ArrayDataType::kUint8);
+ *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
quantization_params->zero_point = input_quantization_params.zero_point;
quantization_params->scale = input_quantization_params.scale;
}
GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
quantization_params);
- *quantized_data_type = ArrayDataType::kUint8;
+ *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
transformation->AddMessageF(
"For output array %s with min=%g, max=%g"
- ", chose to quantize as uint8 with zero_point=%d"
+ ", chose to quantize as %s with zero_point=%d"
", scale=%g",
- output, minmax.min, minmax.max, quantization_params->zero_point,
- quantization_params->scale);
+ output, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
+ quantization_params->zero_point, quantization_params->scale);
return true;
}