Add basic support for quantized unfused LSTMs.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Feb 2018 19:45:57 +0000 (11:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 19:52:35 +0000 (11:52 -0800)
PiperOrigin-RevId: 186650338

tensorflow/contrib/lite/toco/args.h
tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
tensorflow/contrib/lite/toco/toco_flags.proto
tensorflow/contrib/lite/toco/toco_tooling.cc

index b97a472..59a6115 100644 (file)
@@ -229,6 +229,7 @@ struct ParsedTocoFlags {
   // Deprecated flags
   Arg<string> input_type;
   Arg<string> input_types;
+  Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
   Arg<bool> drop_control_dependency = Arg<bool>(false);
 };
 
index 1b0be85..938d763 100644 (file)
@@ -125,6 +125,27 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
   return changed;
 }
 
+bool HardcodeMinMaxForSplit(Model* model, Operator* op) {
+  for (const auto& output : op->outputs) {
+    if (model->GetArray(output).minmax) {
+      LOG(WARNING) << "Skipping min-max setting for " << LogName(*op)
+                   << " because output " << output << " already has min-max.";
+      return false;
+    }
+  }
+  // Data is in second input.
+  auto& input_array = model->GetArray(op->inputs[1]);
+  if (!input_array.minmax) {
+    return false;
+  } else {
+    for (const auto& output : op->outputs) {
+      auto& array = model->GetArray(output);
+      array.GetOrCreateMinMax() = *input_array.minmax;
+    }
+    return true;
+  }
+}
+
 // The output of average or max pooling is within the same range as its input.
 bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
   auto& output_array = model->GetArray(op->outputs[0]);
@@ -296,6 +317,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
       changed = HardcodeMinMaxForConcatenation(model, op);
       break;
 
+    case OperatorType::kTensorFlowSplit:
+      changed = HardcodeMinMaxForSplit(model, op);
+      break;
+
     case OperatorType::kAveragePool:
     case OperatorType::kMaxPool:
       changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
index c5a62fd..0f67c2d 100644 (file)
@@ -112,6 +112,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
           "If true, ignore control dependency requirements in input TensorFlow "
           "GraphDef. Otherwise an error will be raised upon control dependency "
           "inputs."),
+      Flag("debug_disable_recurrent_cell_fusion",
+           parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
+           parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
+           "If true, disable fusion of known identifiable cell subgraphs into "
+           "cells. This includes, for example, specific forms of LSTM cell."),
   };
   bool asked_for_help =
       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
index 3b9d7e2..3237147 100644 (file)
@@ -36,7 +36,8 @@ enum FileFormat {
 // are not normally encoded in model files and in general may not be thought
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
-// Next Id: 13
+//
+// Next ID to use: 14.
 message TocoFlags {
   // Input file format
   optional FileFormat input_format = 1;
@@ -136,4 +137,8 @@ message TocoFlags {
   //    - Default to false if the output format is TENSORFLOW_GRAPHDEF.
   //    - Default to true in all other cases.
   optional bool drop_control_dependency = 12;
+
+  // Disables transformations that fuse subgraphs such as known LSTMs (not all
+  // LSTMs are identified).
+  optional bool debug_disable_recurrent_cell_fusion = 13;
 }
index 1b836fb..6fcaa95 100644 (file)
@@ -234,7 +234,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   }
   transformations.Add(new ConvertPureConvToDepthwise);
   if (SupportsLstmCell(output_format)) {
-    transformations.Add(new IdentifyLstmCell);
+    if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
+      transformations.Add(new IdentifyLstmCell);
+    }
     if (output_format == TFLITE) {
       transformations.Add(new toco::SplitLstmCellInputs);
     } else {