Further small support for quantized unfused LSTMs.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 7 Mar 2018 21:09:07 +0000 (13:09 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Mar 2018 21:13:33 +0000 (13:13 -0800)
PiperOrigin-RevId: 188221169

tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
tensorflow/contrib/lite/toco/tooling_util.cc

index 7731675..6c3e5fd 100644 (file)
@@ -222,7 +222,50 @@ ArrayDataType GetQuantizedDataType(const Array& array,
     default:
       LOG(FATAL) << "Unhandled final quantization type "
                  << static_cast<int>(array.final_data_type);
-      return default_type;
+  }
+}
+
+void GetQuantizationParams(ArrayDataType data_type,
+                           const ModelFlags& model_flags, const MinMax& minmax,
+                           QuantizationParams* quantization_params) {
+  switch (data_type) {
+    case ArrayDataType::kInt8:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kInt8>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kUint8:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kInt16:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kUint16:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kUint16>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kInt32:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kInt32>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kUint32:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kUint32>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kInt64:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kInt64>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kUint64:
+      GetQuantizationParamsFromMinMax<ArrayDataType::kUint64>(
+          model_flags, minmax, quantization_params);
+      break;
+    case ArrayDataType::kFloat:
+    case ArrayDataType::kNone:
+    default:
+      LOG(FATAL) << "Unhandled final quantization type "
+                 << static_cast<int>(data_type);
   }
 }
 
@@ -284,16 +327,16 @@ bool ChooseQuantizationForOperatorInput(
 
   if (op.type == OperatorType::kLstmCell) {
     if (input_index == LstmCellOperator::PREV_STATE_INPUT) {
-      GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
-          model->flags, minmax, quantization_params);
       *quantized_data_type = ArrayDataType::kInt16;
+      GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+                            quantization_params);
       return true;
     }
   }
 
-  GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
-                                                         quantization_params);
   *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
+  GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+                        quantization_params);
   transformation->AddMessageF(
       "For input array %s with min=%g"
       ", max=%g"
@@ -416,15 +459,15 @@ bool ChooseQuantizationForOperatorOutput(
   if (op.type == OperatorType::kLstmCell) {
     if (output_index == LstmCellOperator::STATE_OUTPUT ||
         output_index == LstmCellOperator::ACTIV_TEMP) {
-      GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
-          model->flags, minmax, quantization_params);
       *quantized_data_type = ArrayDataType::kInt16;
+      GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+                            quantization_params);
       return true;
     }
   }
-  GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
-                                                         quantization_params);
   *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
+  GetQuantizationParams(*quantized_data_type, model->flags, minmax,
+                        quantization_params);
   transformation->AddMessageF(
       "For output array %s with min=%g, max=%g"
       ", chose to quantize as %s with zero_point=%d"
index f92e107..48aad89 100644 (file)
@@ -1809,7 +1809,10 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
 void CheckFinalDataTypesSatisfied(const Model& model) {
   for (const auto& array_entry : model.GetArrayMap()) {
     const auto& array = *array_entry.second;
-    if (array.final_data_type != ArrayDataType::kNone) {
+    // If the final data type is int16, the data type may be float, for example
+    // after dequantization.
+    if (array.final_data_type != ArrayDataType::kNone &&
+        array.final_data_type != ArrayDataType::kInt16) {
       CHECK(array.final_data_type == array.data_type)
           << "Array \"" << array_entry.first
           << "\" has mis-matching actual and final data types ("