quantized LSTM support improvements
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Apr 2018 21:05:03 +0000 (14:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 21:07:30 +0000 (14:07 -0700)
PiperOrigin-RevId: 191794956

tensorflow/contrib/lite/toco/export_tensorflow.cc
tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
tensorflow/contrib/lite/toco/model_cmdline_flags.cc
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc
tensorflow/contrib/lite/toco/tooling_util.h

index 5d51431..4a77196 100644 (file)
@@ -37,6 +37,7 @@ limitations under the License.
 
 using tensorflow::DT_BOOL;
 using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT16;
 using tensorflow::DT_INT32;
 using tensorflow::DT_INT64;
 using tensorflow::DT_UINT8;
@@ -1868,6 +1869,9 @@ void AddPlaceholder(const string& name, ArrayDataType type,
     case ArrayDataType::kInt64:
       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
       break;
+    case ArrayDataType::kInt16:
+      (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
+      break;
     default:
       LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
   }
index 935da9f..183b3d3 100644 (file)
@@ -78,15 +78,21 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
   image_input_op->outputs = {dequantized_input_name};
   model->operators.emplace(model->operators.begin(), image_input_op);
 
-  CHECK(input_array.final_data_type == ArrayDataType::kUint8);
-  input_array.data_type = ArrayDataType::kUint8;
   dequantized_input_array.data_type = ArrayDataType::kFloat;
   const auto& input_minmax = input_array.GetMinMax();
   auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
   dequantized_input_minmax = input_minmax;
   auto& input_qparams = input_array.GetOrCreateQuantizationParams();
-  GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
-                                                         &input_qparams);
+  input_array.data_type = input_array.final_data_type;
+  if (input_array.data_type == ArrayDataType::kUint8) {
+    GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
+                                                           &input_qparams);
+  } else if (input_array.data_type == ArrayDataType::kInt16) {
+    GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(input_minmax,
+                                                           &input_qparams);
+  } else {
+    LOG(FATAL) << "unhandled data type";
+  }
 
   transformation->AddMessageF(
       "Created %s"
index 245eb52..0fa6e85 100644 (file)
@@ -402,9 +402,10 @@ void ReadModelFlagsFromCommandLineFlags(
 
   if (parsed_model_flags.arrays_extra_info_file.specified()) {
     string arrays_extra_info_file_contents;
-    port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
-                            &arrays_extra_info_file_contents,
-                            port::file::Defaults());
+    CHECK(port::file::GetContents(
+              parsed_model_flags.arrays_extra_info_file.value(),
+              &arrays_extra_info_file_contents, port::file::Defaults())
+              .ok());
     ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
                                       model_flags->mutable_arrays_extra_info());
   }
index 76e9a27..96c5ebd 100644 (file)
@@ -130,20 +130,26 @@ bool SupportsPreallocatedWorkspace(FileFormat format) {
 }
 
 bool IsRealValued(toco::ArrayDataType type) {
+  // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
+  // for quantized real-number values, and no other integer type is ever used
+  // for that. This is dirty, should be resolved as part of a more general push
+  // to more explicitly distinguish between true-integers and
+  // integers used as quantized values representing real numbers.
   return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
-                           type == toco::ArrayDataType::kUint8);
+                           type == toco::ArrayDataType::kUint8 ||
+                           type == toco::ArrayDataType::kInt16);
 }
 
 void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
   const FileFormat output_format = toco_flags.output_format();
   ArrayDataType type;
-  if (toco_flags.has_inference_input_type()) {
+  if (!SupportsQuantization(output_format)) {
+    // Data type is implicitly float for non-quantized formats
+    type = ArrayDataType::kFloat;
+  } else if (toco_flags.has_inference_input_type()) {
     type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
   } else if (toco_flags.has_inference_type()) {
     type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
-  } else if (!SupportsQuantization(output_format)) {
-    // Data type is implicitly float for non-quantized formats
-    type = ArrayDataType::kFloat;
   } else {
     // Nothing to do. Data types stay as-is.
     return;
@@ -198,11 +204,6 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
 }
 
 void Transform(const TocoFlags& toco_flags, Model* model) {
-  // Clean up after import.
-  SetFinalDataTypeOnInputs(toco_flags, model);
-  UseArraysExtraInfo(model);
-  FinishBuildingRNNStates(model);
-
   const FileFormat output_format = toco_flags.output_format();
   const IODataType inference_type = toco_flags.inference_type();
 
@@ -215,6 +216,11 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
         << "Quantized inference is not allowed with float inputs.";
   }
 
+  // Clean up after import.
+  SetFinalDataTypeOnInputs(toco_flags, model);
+  UseArraysExtraInfo(model, quantize_output);
+  FinishBuildingRNNStates(model);
+
   // Remove unused ops before performing any other optimizations. This is to
   // stop optimizations from crossing the input/output boundaries. For example
   // this will stop BatchNorm fusing if the output node is in between a conv
index 56fa8f4..61d08fa 100644 (file)
@@ -1378,12 +1378,22 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
     const float mean_value = input_array_proto.mean_value();
     const float std_value = input_array_proto.std_value();
     MinMax input_minmax;
-    input_minmax.min = (0.f - mean_value) / std_value;
-    input_minmax.max = (255.f - mean_value) / std_value;
+    float qmin = 0, qmax = 255;
+    if (input_array.data_type == ArrayDataType::kInt16) {
+      qmin = -32768;
+      qmax = 32767;
+    }
+    input_minmax.min = (qmin - mean_value) / std_value;
+    input_minmax.max = (qmax - mean_value) / std_value;
     if (input_array.minmax) {
       if (input_array_proto.has_mean_value() ||
           input_array_proto.has_std_value()) {
-        CHECK(input_minmax == *input_array.minmax)
+        const double width = input_minmax.max - input_minmax.min;
+        const double kMinMaxAllowedDiff = 1e-6 * width;
+        CHECK(std::abs(input_minmax.min - input_array.minmax->min) <
+                  kMinMaxAllowedDiff &&
+              std::abs(input_minmax.max - input_array.minmax->max) <
+                  kMinMaxAllowedDiff)
             << input_minmax.min << ", " << input_minmax.max
             << " != " << input_array.minmax->min << ", "
             << input_array.minmax->max;
@@ -2000,7 +2010,7 @@ void FinishBuildingRNNStates(Model* model) {
   }
 }
 
-void UseArraysExtraInfo(Model* model) {
+void UseArraysExtraInfo(Model* model, bool quantize_output) {
   for (const auto& entry : model->flags.arrays_extra_info().entries()) {
     if (!model->HasArray(entry.name())) {
       continue;
@@ -2012,7 +2022,7 @@ void UseArraysExtraInfo(Model* model) {
       minmax.min = entry.min();
       minmax.max = entry.max();
     }
-    if (entry.has_data_type()) {
+    if (entry.has_data_type() && quantize_output) {
       array.final_data_type =
           ConvertIODataTypeToArrayDataType(entry.data_type());
     }
index 259ee7f..dfd8117 100644 (file)
@@ -285,7 +285,7 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);
 // already quantized, then case (a) should hold.
 void FinishBuildingRNNStates(Model* model);
 
-void UseArraysExtraInfo(Model* model);
+void UseArraysExtraInfo(Model* model, bool quantize_output);
 
 }  // namespace toco