From da66104f3d5e3b5ba5e79d0beccd1b91b9578bee Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 22 Feb 2018 11:45:57 -0800
Subject: [PATCH] Add basic support for quantized unfused LSTMs.

PiperOrigin-RevId: 186650338
---
 tensorflow/contrib/lite/toco/args.h                 |  1 +
 .../toco/graph_transformations/hardcode_min_max.cc  | 25 ++++++++++++++++++++++
 tensorflow/contrib/lite/toco/toco_cmdline_flags.cc  |  5 +++++
 tensorflow/contrib/lite/toco/toco_flags.proto       |  7 +++++-
 tensorflow/contrib/lite/toco/toco_tooling.cc        |  4 +++-
 5 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index b97a472..59a6115 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -229,6 +229,7 @@ struct ParsedTocoFlags {
   // Deprecated flags
   Arg<string> input_type;
   Arg<string> input_types;
+  Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
   Arg<bool> drop_control_dependency = Arg<bool>(false);
 };
 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 1b0be85..938d763 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -125,6 +125,27 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
   return changed;
 }
 
+bool HardcodeMinMaxForSplit(Model* model, Operator* op) {
+  for (const auto& output : op->outputs) {
+    if (model->GetArray(output).minmax) {
+      LOG(WARNING) << "Skipping min-max setting for " << LogName(*op)
+                   << " because output " << output << " already has min-max.";
+      return false;
+    }
+  }
+  // Data is in second input.
+  auto& input_array = model->GetArray(op->inputs[1]);
+  if (!input_array.minmax) {
+    return false;
+  } else {
+    for (const auto& output : op->outputs) {
+      auto& array = model->GetArray(output);
+      array.GetOrCreateMinMax() = *input_array.minmax;
+    }
+    return true;
+  }
+}
+
 // The output of average or max pooling is within the same range as its input.
 bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
   auto& output_array = model->GetArray(op->outputs[0]);
@@ -296,6 +317,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
       changed = HardcodeMinMaxForConcatenation(model, op);
       break;
 
+    case OperatorType::kTensorFlowSplit:
+      changed = HardcodeMinMaxForSplit(model, op);
+      break;
+
     case OperatorType::kAveragePool:
     case OperatorType::kMaxPool:
       changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index c5a62fd..0f67c2d 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -112,6 +112,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
            "If true, ignore control dependency requirements in input TensorFlow "
            "GraphDef. Otherwise an error will be raised upon control dependency "
            "inputs."),
+      Flag("debug_disable_recurrent_cell_fusion",
+           parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
+           parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
+           "If true, disable fusion of known identifiable cell subgraphs into "
+           "cells. This includes, for example, specific forms of LSTM cell."),
   };
   bool asked_for_help =
       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 3b9d7e2..3237147 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -36,7 +36,8 @@ enum FileFormat {
 // are not normally encoded in model files and in general may not be thought
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
-// Next Id: 13
+//
+// Next ID to use: 14.
 message TocoFlags {
   // Input file format
   optional FileFormat input_format = 1;
@@ -136,4 +137,8 @@ message TocoFlags {
   // - Default to false if the output format is TENSORFLOW_GRAPHDEF.
   // - Default to true in all other cases.
   optional bool drop_control_dependency = 12;
+
+  // Disables transformations that fuse subgraphs such as known LSTMs (not all
+  // LSTMs are identified).
+  optional bool debug_disable_recurrent_cell_fusion = 13;
 }
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 1b836fb..6fcaa95 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -234,7 +234,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   }
   transformations.Add(new ConvertPureConvToDepthwise);
   if (SupportsLstmCell(output_format)) {
-    transformations.Add(new IdentifyLstmCell);
+    if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
+      transformations.Add(new IdentifyLstmCell);
+    }
     if (output_format == TFLITE) {
       transformations.Add(new toco::SplitLstmCellInputs);
     } else {
-- 
2.7.4