Add a command line parameter to toco to change the way toco rescales input and output...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 00:35:00 +0000 (17:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 00:37:43 +0000 (17:37 -0700)
PiperOrigin-RevId: 191825756

tensorflow/contrib/lite/toco/args.h
tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
tensorflow/contrib/lite/toco/model_cmdline_flags.cc
tensorflow/contrib/lite/toco/model_flags.proto
tensorflow/contrib/lite/toco/tooling_util.cc

index 39e49bc..7a7059e 100644 (file)
@@ -202,6 +202,7 @@ struct ParsedModelFlags {
   Arg<toco::IntList> input_shape;
   Arg<toco::StringMapList> rnn_states;
   Arg<toco::StringMapList> model_checks;
+  Arg<bool> change_concat_input_ranges = Arg<bool>(true);
   // Debugging output options.
   // TODO(benoitjacob): these shouldn't be ModelFlags.
   Arg<string> graphviz_first_array;
index 23c9e32..437e30a 100644 (file)
@@ -95,30 +95,37 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
   overall_minmax.min = overall_min;
   overall_minmax.max = overall_max;
   bool changed = false;
-  for (const auto& input : op->inputs) {
-    auto& array = model->GetArray(input);
-    if (!array.minmax) {
-      changed = true;
-    } else if (!(overall_minmax == array.GetMinMax())) {
-      changed = true;
-      LOG(WARNING)
-          << "Tweaking the MinMax of array " << input << ", which is "
-          << "an input to " << LogName(*op) << ", because we want all inputs "
-          << "and outputs of a Concatenation operator to have the same MinMax "
-          << "so that it can be implemented as a pure byte-copy, no "
-             "arithmetic.";
+  if (model->flags.change_concat_input_ranges()) {
+    for (const auto& input : op->inputs) {
+      auto& array = model->GetArray(input);
+      if (!array.minmax) {
+        changed = true;
+      } else if (!(overall_minmax == array.GetMinMax())) {
+        changed = true;
+        LOG(WARNING)
+            << "Tweaking the MinMax of array " << input << ", which is "
+            << "an input to " << LogName(*op) << ", because we want all inputs "
+            << "and outputs of a Concatenation operator to have the same "
+            << "MinMax so that it can be implemented as a pure byte-copy, no "
+               "arithmetic.";
+      }
+      array.GetOrCreateMinMax() = overall_minmax;
     }
-    array.GetOrCreateMinMax() = overall_minmax;
   }
   if (!output.minmax) {
     changed = true;
   } else if (!(overall_minmax == output.GetMinMax())) {
-    changed = true;
-    LOG(WARNING)
-        << "Tweaking the MinMax of the output array of " << LogName(*op)
-        << ", because we want all inputs "
-        << "and outputs of a Concatenation operator to have the same MinMax "
-        << "so that it can be implemented as a pure byte-copy, no arithmetic.";
+    if (model->flags.change_concat_input_ranges()) {
+      changed = true;
+      LOG(WARNING)
+          << "Tweaking the MinMax of the output array of " << LogName(*op)
+          << ", because we want all inputs "
+          << "and outputs of a Concatenation operator to have the same MinMax "
+          << "so that it can be implemented as a pure byte-copy, no "
+          << "arithmetic.";
+    } else {
+      return false;
+    }
   }
   output.GetOrCreateMinMax() = overall_minmax;
 
index 7784558..5b1268f 100644 (file)
@@ -431,7 +431,8 @@ bool ChooseQuantizationForOperatorOutput(
       (op.type == OperatorType::kSpaceToDepth) ||
       (op.type == OperatorType::kTensorFlowReshape) ||
       (op.type == OperatorType::kTensorFlowSplit) ||
-      (op.type == OperatorType::kConcatenation)) {
+      (op.type == OperatorType::kConcatenation &&
+       model->flags.change_concat_input_ranges())) {
     int data_input_index = 0;
     if (op.type == OperatorType::kTensorFlowSplit) {
       data_input_index = 1;
index 0fa6e85..7bbeab7 100644 (file)
@@ -165,6 +165,11 @@ bool ParseModelFlagsFromCommandLineFlags(
            "Path to an optional file containing a serialized ModelFlags proto. "
            "Options specified on the command line will override the values in "
            "the proto."),
+      Flag("change_concat_input_ranges",
+           parsed_flags.change_concat_input_ranges.bind(),
+           parsed_flags.change_concat_input_ranges.default_value(),
+           "Boolean to change the behavior of min/max ranges for inputs and"
+           " output of the concat operators."),
   };
   bool asked_for_help =
       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -399,6 +404,8 @@ void ReadModelFlagsFromCommandLineFlags(
       parsed_model_flags.allow_nonascii_arrays.value());
   model_flags->set_allow_nonexistent_arrays(
       parsed_model_flags.allow_nonexistent_arrays.value());
+  model_flags->set_change_concat_input_ranges(
+      parsed_model_flags.change_concat_input_ranges.value());
 
   if (parsed_model_flags.arrays_extra_info_file.specified()) {
     string arrays_extra_info_file_contents;
index 835dea4..d23e80c 100644 (file)
@@ -128,7 +128,7 @@ message ArraysExtraInfo {
 //   optional int32 input_dims = 11 [ default = 4];
 //   repeated int32 input_shape = 13;
 //
-// Next ID to USE: 19.
+// Next ID to USE: 20.
 message ModelFlags {
   // Information about the input arrays, i.e. the arrays from which input
   // activations will be read.
@@ -175,4 +175,8 @@ message ModelFlags {
   // If set, this ArraysExtraInfo allows to pass extra information about arrays
   // not specified in the input model file, such as extra MinMax information.
   optional ArraysExtraInfo arrays_extra_info = 18;
+
+  // When set to false, toco will not change the input ranges and the output
+  // ranges of concat operator to the overlap of all input ranges.
+  optional bool change_concat_input_ranges = 19 [default = true];
 }
index 61d08fa..b72f5fa 100644 (file)
@@ -1413,7 +1413,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
       CHECK(input_array.shape().dims_size());
     }
   }
-
+  model->flags.set_change_concat_input_ranges(
+      model_flags.change_concat_input_ranges());
   model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
   model->flags.set_allow_nonexistent_arrays(
       model_flags.allow_nonexistent_arrays());