LSTM support: Quantized types, quantization params for 16-bit unfused LSTMs.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Feb 2018 00:53:40 +0000 (16:53 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 01:07:29 +0000 (17:07 -0800)
PiperOrigin-RevId: 186697357

tensorflow/contrib/lite/toco/dump_graphviz.cc
tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
tensorflow/contrib/lite/toco/model_flags.proto
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc
tensorflow/contrib/lite/toco/types.proto

index 2184e8f..c835274 100644 (file)
@@ -193,12 +193,12 @@ NodeProperties GetPropertiesForArray(const Model& model,
   }
 
   if (array.minmax) {
-    AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]",
+    AppendF(&node_properties.label, "\\nMinMax: [%.7g, %.7g]",
             array.minmax->min, array.minmax->max);
   }
 
   if (array.quantization_params) {
-    AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)",
+    AppendF(&node_properties.label, "\\nQuantization: %7g * (x - %d)",
             array.quantization_params->scale,
             array.quantization_params->zero_point);
   }
index d7f804e..7731675 100644 (file)
@@ -100,7 +100,13 @@ void QuantizeArray(GraphTransformation* transformation, Model* model,
 void QuantizeArray(GraphTransformation* transformation, Model* model,
                    const string& name, ArrayDataType quantized_data_type,
                    const QuantizationParams& quantization_params) {
-  switch (quantized_data_type) {
+  ArrayDataType adjusted_data_type = quantized_data_type;
+  auto& array = model->GetArray(name);
+  if (array.final_data_type == ArrayDataType::kInt16) {
+    adjusted_data_type = array.final_data_type;
+  }
+
+  switch (adjusted_data_type) {
     case ArrayDataType::kUint8:
       return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
                                                   quantization_params);
@@ -166,6 +172,60 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
                 "proceed with quantization.";
 }
 
+struct QuantizationPoints {
+  int64 min_value;
+  int64 max_value;
+  int64 central_value;
+};
+
+template <ArrayDataType A>
+QuantizationPoints GetQuantizationPoints() {
+  QuantizationPoints qp;
+  using Integer = DataType<A>;
+  qp.min_value = std::numeric_limits<Integer>::min();
+  qp.max_value = std::numeric_limits<Integer>::max();
+  // eg [-128,127]...
+  qp.central_value = (qp.min_value / 2 +        // -128 -> -64.
+                      (qp.max_value - 1) / 2 +  // 127 -> 63.
+                      1);
+  return qp;
+}
+
+QuantizationPoints GetQuantizationPoints(ArrayDataType data_type) {
+  switch (data_type) {
+    case ArrayDataType::kUint8:
+      return GetQuantizationPoints<ArrayDataType::kUint8>();
+    case ArrayDataType::kInt16:
+      return GetQuantizationPoints<ArrayDataType::kInt16>();
+    case ArrayDataType::kInt32:
+      return GetQuantizationPoints<ArrayDataType::kInt32>();
+    default:
+      LOG(FATAL) << "Unhandled case.";
+  }
+}
+
+ArrayDataType GetQuantizedDataType(const Array& array,
+                                   ArrayDataType default_type) {
+  switch (array.final_data_type) {
+    case ArrayDataType::kInt8:
+    case ArrayDataType::kUint8:
+    case ArrayDataType::kInt16:
+    case ArrayDataType::kUint16:
+    case ArrayDataType::kInt32:
+    case ArrayDataType::kUint32:
+    case ArrayDataType::kInt64:
+    case ArrayDataType::kUint64:
+      return array.final_data_type;
+    case ArrayDataType::kFloat:
+    case ArrayDataType::kNone:
+      return default_type;
+    default:
+      LOG(FATAL) << "Unhandled final quantization type "
+                 << static_cast<int>(array.final_data_type);
+      return default_type;
+  }
+}
+
 bool ChooseQuantizationForOperatorInput(
     GraphTransformation* transformation, Model* model, const Operator& op,
     std::size_t input_index, ArrayDataType* quantized_data_type,
@@ -212,7 +272,7 @@ bool ChooseQuantizationForOperatorInput(
     const auto input_weights_scale = input_weights.quantization_params->scale;
     quantization_params->scale = input_activations_scale * input_weights_scale;
     quantization_params->zero_point = 0;
-    *quantized_data_type = ArrayDataType::kInt32;
+    *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kInt32);
     transformation->AddMessageF(
         "Input array %s is a bias vector. Choosing quantization params "
         "accordingly.",
@@ -233,14 +293,14 @@ bool ChooseQuantizationForOperatorInput(
 
   GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
                                                          quantization_params);
+  *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
   transformation->AddMessageF(
       "For input array %s with min=%g"
       ", max=%g"
-      ", chose to quantize as uint8 with zero_point=%d"
+      ", chose to quantize as %s with zero_point=%d"
       ", scale=%g",
-      input, minmax.min, minmax.max, quantization_params->zero_point,
-      quantization_params->scale);
-  *quantized_data_type = ArrayDataType::kUint8;
+      input, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
+      quantization_params->zero_point, quantization_params->scale);
   return true;
 }
 
@@ -262,16 +322,18 @@ bool IsExactlyRepresentable(double real_value, ArrayDataType data_type,
   return true;
 }
 
+// Quantized data type is preset to the type of the input before this function.
 bool ChooseHardcodedQuantizationForOperatorOutput(
-    const Operator& op, ArrayDataType* quantized_data_type,
+    const Operator& op, const Array& array, ArrayDataType* quantized_data_type,
     QuantizationParams* quantization_params) {
   if (op.type == OperatorType::kL2Normalization) {
     // L2Normalization has range: [-1, 1].
     // 0 should be exactly representable, as values will typically be centered
     // around 0, with many values near 0.
-    *quantized_data_type = ArrayDataType::kUint8;
-    quantization_params->zero_point = 128;
-    quantization_params->scale = 1. / 128.;
+    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
+    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
+    quantization_params->zero_point = qp.central_value;
+    quantization_params->scale = 1. / (qp.central_value - qp.min_value);
     CHECK(
         IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
     return true;
@@ -284,18 +346,20 @@ bool ChooseHardcodedQuantizationForOperatorOutput(
     // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and
     // the glueing of the two halves of the graph will only be seamless if we
     // are accurately representing logistic(0) == 0.5.
-    *quantized_data_type = ArrayDataType::kUint8;
+    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
+    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
     quantization_params->zero_point = 0;
-    quantization_params->scale = 1. / 256.;
+    quantization_params->scale = 1. / (qp.max_value + 1);
     CHECK(IsExactlyRepresentable(0.5, *quantized_data_type,
                                  *quantization_params));
     return true;
   }
   if (op.type == OperatorType::kTanh) {
     // Tanh has the range: [-1, 1].
-    *quantized_data_type = ArrayDataType::kUint8;
-    quantization_params->zero_point = 128;
-    quantization_params->scale = 1. / 128.;
+    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
+    const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type);
+    quantization_params->zero_point = qp.central_value;
+    quantization_params->scale = 1. / (qp.central_value - qp.min_value);
     // 0 should be exactly representable, as values will typically be centered
     // around 0, with many values near 0.
     CHECK(
@@ -314,8 +378,9 @@ bool ChooseQuantizationForOperatorOutput(
   if (array.data_type != ArrayDataType::kFloat) {
     return false;
   }
-  if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type,
-                                                   quantization_params)) {
+  *quantized_data_type = model->GetArray(op.inputs[0]).data_type;
+  if (ChooseHardcodedQuantizationForOperatorOutput(
+          op, array, quantized_data_type, quantization_params)) {
     transformation->AddMessageF(
         "Output array %s is produced by a %s operator. Choosing fixed "
         "quantization params accordingly.",
@@ -323,12 +388,21 @@ bool ChooseQuantizationForOperatorOutput(
     return true;
   }
   if ((op.type == OperatorType::kDepthToSpace) ||
-      (op.type == OperatorType::kSpaceToDepth)) {
-    // DepthToSpace and SpaceToDepth should preserve the quantization parameters
-    // of the input array, as these are simple reshape operations.
-    const auto& input_quantization_params =
-        model->GetArray(op.inputs[0]).GetQuantizationParams();
-    *quantized_data_type = ArrayDataType::kUint8;
+      (op.type == OperatorType::kSpaceToDepth) ||
+      (op.type == OperatorType::kTensorFlowReshape) ||
+      (op.type == OperatorType::kTensorFlowSplit) ||
+      (op.type == OperatorType::kConcatenation)) {
+    int data_input_index = 0;
+    if (op.type == OperatorType::kTensorFlowSplit) {
+      data_input_index = 1;
+    }
+    // Copying and rearrangement ops should preserve the quantization parameters
+    // of the input array.
+    const auto& input_array = model->GetArray(op.inputs[data_input_index]);
+    const auto& input_quantization_params = input_array.GetQuantizationParams();
+    *quantized_data_type =
+        GetQuantizedDataType(input_array, ArrayDataType::kUint8);
+    *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type);
     quantization_params->zero_point = input_quantization_params.zero_point;
     quantization_params->scale = input_quantization_params.scale;
 
@@ -350,13 +424,13 @@ bool ChooseQuantizationForOperatorOutput(
   }
   GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
                                                          quantization_params);
-  *quantized_data_type = ArrayDataType::kUint8;
+  *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
   transformation->AddMessageF(
       "For output array %s with min=%g, max=%g"
-      ", chose to quantize as uint8 with zero_point=%d"
+      ", chose to quantize as %s with zero_point=%d"
       ", scale=%g",
-      output, minmax.min, minmax.max, quantization_params->zero_point,
-      quantization_params->scale);
+      output, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type),
+      quantization_params->zero_point, quantization_params->scale);
 
   return true;
 }
index e4b39b3..867b86f 100644 (file)
@@ -96,9 +96,11 @@ message RnnState {
 // model that does not already contain such MinMax information.
 message ArraysExtraInfo {
   message Entry {
+    // Next ID to use: 5.
     optional string name = 1;
     optional float min = 2;
     optional float max = 3;
+    optional IODataType data_type = 4;
   }
   repeated Entry entries = 1;
 }
index 2153bab..a09a3c4 100644 (file)
@@ -199,7 +199,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   const IODataType inference_type = toco_flags.inference_type();
 
   const bool quantize_output =
-      SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;
+      SupportsQuantization(output_format) &&
+      (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16);
 
   if (quantize_output) {
     QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
index eec35b7..9e72582 100644 (file)
@@ -1801,6 +1801,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
       return ArrayDataType::kFloat;
     case QUANTIZED_UINT8:
       return ArrayDataType::kUint8;
+    case QUANTIZED_INT16:
+      return ArrayDataType::kInt16;
     case INT32:
       return ArrayDataType::kInt32;
     case INT64:
@@ -1832,9 +1834,17 @@ void UseArraysExtraInfo(Model* model) {
     QCHECK(model->HasArray(entry.name()))
         << "ArraysExtraInfo refers to non-existent array name: "
         << entry.name();
-    auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax();
-    minmax.min = entry.min();
-    minmax.max = entry.max();
+    auto& array = model->GetArray(entry.name());
+    auto& minmax = array.GetOrCreateMinMax();
+    if (entry.has_min() || entry.has_max()) {
+      CHECK_EQ(entry.has_min(), entry.has_max());
+      minmax.min = entry.min();
+      minmax.max = entry.max();
+    }
+    if (entry.has_data_type()) {
+      array.final_data_type =
+          ConvertIODataTypeToArrayDataType(entry.data_type());
+    }
   }
 }
 
index 318fd4b..03bd615 100644 (file)
@@ -34,4 +34,7 @@ enum IODataType {
 
   // String, not quantized
   STRING = 5;
+
+  // Int16, quantized
+  QUANTIZED_INT16 = 6;
 }