Fix up output array min/max post-quantization if the range was overridden.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Apr 2018 20:24:06 +0000 (13:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 20:26:25 +0000 (13:26 -0700)
PiperOrigin-RevId: 191637143

tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
tensorflow/contrib/lite/toco/graph_transformations/quantize.cc

index 7c97ef0..23c9e32 100644 (file)
@@ -223,8 +223,11 @@ bool PropagateMinMaxAmongArrays(Model* model,
     if (array.minmax) {
       CHECK(*array.minmax == *reference_minmax)
           << "Both the following arrays have minmax, and they disagree: "
-          << reference_array_name << " and " << array_name
-          << ". Expected that either only one of them would have minmax, or at "
+          << reference_array_name << " (" << reference_minmax->min << ","
+          << reference_minmax->max << ") and " << array_name << " ("
+          << array.minmax->min << "," << array.minmax->max
+          << "). Expected that either only one of them would have minmax, or "
+             "at "
              "least that they would agree.";
     } else {
       array.GetOrCreateMinMax() = *reference_minmax;
index 9fcc95e..7784558 100644 (file)
@@ -472,6 +472,44 @@ bool ChooseQuantizationForOperatorOutput(
 
   return true;
 }
+
+// Fixes array minmax info to match the quantization parameters.
+// This is required for when quantization parameters change for an array during
+// quantization (such as ChooseQuantizationForOperatorOutput).
+void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
+                               const QuantizationParams& quantization_params,
+                               MinMax* minmax) {
+  double qmin, qmax;
+  switch (quantized_data_type) {
+    case ArrayDataType::kUint8:
+      qmin = 0;
+      qmax = 255;
+      break;
+    case ArrayDataType::kInt16:
+      qmin = -32768;
+      qmax = 32767;
+      break;
+    default:
+      // No update required.
+      return;
+  }
+
+  // Compute new minmax values.
+  double min =
+      (qmin - quantization_params.zero_point) * quantization_params.scale;
+  double max =
+      (qmax - quantization_params.zero_point) * quantization_params.scale;
+
+  // If we are close to the existing minmax values don't bother changing them.
+  // This prevents propagating small floating point precision errors.
+  constexpr double kMinMaxThreshold = 1e-5;
+  const double width = max - min;
+  if (std::abs(min - minmax->min) > kMinMaxThreshold * width ||
+      std::abs(max - minmax->max) > kMinMaxThreshold * width) {
+    minmax->min = min;
+    minmax->max = max;
+  }
+}
 }  // namespace
 
 bool Quantize::Run(Model* model, std::size_t op_index) {
@@ -618,12 +656,19 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
                                             &quantization_params)) {
       changed = true;
       const auto& output = op.outputs[output_index];
+      auto& output_array = model->GetArray(output);
+
+      // Fix up the min/max information on the output array to match the chosen
+      // quantization parameters.
+      auto& output_minmax = output_array.GetMinMax();
+      FixMinMaxPostQuantization(quantized_data_type, quantization_params,
+                                &output_minmax);
+
       QuantizeArray(this, model, output, quantized_data_type,
                     quantization_params);
+
       const auto& dequantized_output =
           AvailableArrayName(*model, output + "_dequantized");
-      const auto& output_array = model->GetArray(output);
-      const auto& output_minmax = output_array.GetMinMax();
       auto& dequantized_output_array =
           model->GetOrCreateArray(dequantized_output);
       dequantized_output_array.data_type = ArrayDataType::kFloat;