Allow passing dummy/custom minmax information on a per-array basis,
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 26 Jan 2018 02:42:44 +0000 (18:42 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 26 Jan 2018 02:46:27 +0000 (18:46 -0800)
unlike the existing --default_ranges_{min,max} flags which only allowed
to set a single global value for all arrays.

This takes the form of a new embedded message in ModelFlags, which is
its own message so that it can be serialized separately. The command-line
interface is --arrays_extra_info_file=some_proto.pbtxt, i.e. we don't
try to make a command-line-flags-only interface, we mandate putting the info
in a file. The rationale is that users may want to specify custom minmax
for hundreds of arrays, so it would be cumbersome to have that all in a
command line.

This should be considered an experimental feature, in the sense that
in properly quantized models, minmax information is already embedded
in the graph (e.g. in FakeQuant nodes). This is an extension of the
existing --default_ranges_{min,max} feature which had turned out to be
too restrictive for many users.

PiperOrigin-RevId: 183326000

tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/args.h
tensorflow/contrib/lite/toco/model_cmdline_flags.cc
tensorflow/contrib/lite/toco/model_flags.proto
tensorflow/contrib/lite/toco/toco_port.h
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc
tensorflow/contrib/lite/toco/tooling_util.h

index f993fd6a00f054c670b247e886a1d9d2a34643e7..fc5897896477711c46b06f10003acb10783d12af 100644 (file)
@@ -1504,7 +1504,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
       << "*\n"
       << "* If you would like to carry on with the slow code, compile\n"
       << "* with this preprocessor token defined:\n"
-      << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+      << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
       << "*\n"
       << "* The right thing to do, if you care about performance, is to add\n"
       << "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
index 041e2487903c63572a7acda17f2f3ebc701be0c7..6fc7e5e3fdd4da8f8b224b8c10a6be8154204c94 100644 (file)
@@ -160,6 +160,7 @@ cc_library(
     ],
     deps = [
         # Placeholder for internal file dependency.
+        "@protobuf_archive//:protobuf_headers",
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
index 8004a1a37ae48468e9bf22785ec02f8de54bf236..b97a4720a7c4e69f8b69574475d19e0522cfe86d 100644 (file)
@@ -208,6 +208,7 @@ struct ParsedModelFlags {
   Arg<bool> dump_graphviz_video = Arg<bool>(false);
   Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
   Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
+  Arg<string> arrays_extra_info_file;
 };
 
 // Flags that describe the operation you would like to do (what conversion
index 36520d9c55c83522b00a8d5e51a243f715731a83..4e2dec15a534607ef9207149a2e6061069eabcb1 100644 (file)
@@ -148,6 +148,12 @@ bool ParseModelFlagsFromCommandLineFlags(
            "ranging from 32 to 127. This is disallowed by default so as to "
            "catch common copy-and-paste issues where invisible unicode "
            "characters are unwittingly added to these strings."),
+      Flag(
+          "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
+          parsed_flags.arrays_extra_info_file.default_value(),
+          "Path to an optional file containing a serialized ArraysExtraInfo "
+          "proto allowing to pass extra information about arrays not specified "
+          "in the input model file, such as extra MinMax information."),
   };
   bool asked_for_help =
       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -365,6 +371,15 @@ void ReadModelFlagsFromCommandLineFlags(
       parsed_model_flags.allow_nonascii_arrays.value());
   model_flags->set_allow_nonexistent_arrays(
       parsed_model_flags.allow_nonexistent_arrays.value());
+
+  if (parsed_model_flags.arrays_extra_info_file.specified()) {
+    string arrays_extra_info_file_contents;
+    port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
+                            &arrays_extra_info_file_contents,
+                            port::file::Defaults());
+    ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
+                                      model_flags->mutable_arrays_extra_info());
+  }
 }
 
 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
index 9070ddc88351faabdb1172f9601c0351728bfc46..e4b39b34e85e4d703c1b41cb68f8139abd1f6279 100644 (file)
@@ -87,6 +87,22 @@ message RnnState {
   optional int32 size = 3;
 }
 
+// An ArraysExtraInfo message stores a collection of additional Information
+// about arrays in a model, complementing the information in the model itself.
+// It is intentionally a separate message so that it may be serialized and
+// passed separately from the model. See --arrays_extra_info_file.
+//
+// A typical use case is to manually specify MinMax for specific arrays in a
+// model that does not already contain such MinMax information.
+message ArraysExtraInfo {
+  message Entry {
+    optional string name = 1;
+    optional float min = 2;
+    optional float max = 3;
+  }
+  repeated Entry entries = 1;
+}
+
 // ModelFlags encodes properties of a model that, depending on the file
 // format, may or may not be recorded in the model file. The purpose of
 // representing these properties in ModelFlags is to allow passing them
@@ -108,7 +124,7 @@ message RnnState {
 //   optional int32 input_dims = 11 [ default = 4];
 //   repeated int32 input_shape = 13;
 //
-// Next ID to USE: 18.
+// Next ID to USE: 19.
 message ModelFlags {
   // Information about the input arrays, i.e. the arrays from which input
   // activations will be read.
@@ -151,4 +167,8 @@ message ModelFlags {
   // catch common copy-and-paste issues where invisible unicode
   // characters are unwittingly added to these strings.
   optional bool allow_nonascii_arrays = 17;
+
+  // If set, this ArraysExtraInfo allows to pass extra information about arrays
+  // not specified in the input model file, such as extra MinMax information.
+  optional ArraysExtraInfo arrays_extra_info = 18;
 }
index 0572848cb5a998457cd669a2b0bce5fe8a0e15a2..4be3b5a0bf00ed204a1218545d9e66f7685a50d7 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 // can build and use on google internal environments and on OSX.
 
 #include <string>
+#include "google/protobuf/text_format.h"
 #include "tensorflow/contrib/lite/toco/format_port.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/platform.h"
@@ -75,6 +76,26 @@ void CopyToBuffer(const ::Cord& src, char* dest);
 #endif  // PLATFORM_GOOGLE
 void CopyToBuffer(const string& src, char* dest);
 }  // namespace port
+
+inline bool ParseFromStringOverload(const std::string& in,
+                                    TFLITE_PROTO_NS::Message* proto) {
+  return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
+}
+
+template <typename Proto>
+bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
+                                       Proto* proto) {
+  if (proto->ParseFromString(input_file_contents)) {
+    return true;
+  }
+
+  if (ParseFromStringOverload(input_file_contents, proto)) {
+    return true;
+  }
+
+  return false;
+}
+
 }  // namespace toco
 
 #endif  // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
index 720c33777d707994c6e1003bb1210eadd96bc8a8..727df1cc76ae332682a50db534e6bfa20ffc45ca 100644 (file)
@@ -193,6 +193,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   }
 
   SetFinalDataTypeOnInputs(toco_flags, model);
+  UseArraysExtraInfo(model);
 
   // Remove unused ops before performing any other optimizations. This is to
   // stop optimizations from crossing the input/output boundaries. For example
@@ -232,6 +233,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   transformations.Add(new ResolveConstantConcatenation);
   RunGraphTransformations(model, "general graph transformations",
                           transformations);
+
   if (quantize_output) {
     RunGraphTransformations(model, "pre-quantization graph transformations",
                             {new HardcodeMinMax, new DropFakeQuant});
index df785a5102afd9a3b3fec4e35684e196bfd0d935..3728d486597965df68ff427578265ca9774d9138 100644 (file)
@@ -1200,6 +1200,9 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
   model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
   model->flags.set_allow_nonexistent_arrays(
       model_flags.allow_nonexistent_arrays());
+
+  CHECK(!model->flags.has_arrays_extra_info());
+  *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
 }
 
 void CheckIsReadyForQuantization(const Model& model) {
@@ -1711,4 +1714,15 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
   }
 }
 
+void UseArraysExtraInfo(Model* model) {
+  for (const auto& entry : model->flags.arrays_extra_info().entries()) {
+    QCHECK(model->HasArray(entry.name()))
+        << "ArraysExtraInfo refers to non-existent array name: "
+        << entry.name();
+    auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax();
+    minmax.min = entry.min();
+    minmax.max = entry.max();
+  }
+}
+
 }  // namespace toco
index 5986d6364939e0f01b057ce3fb653b19fe8040cd..2ac51c7e5bb4653f47414a1d6f8e1ed8862ddf7e 100644 (file)
@@ -23,7 +23,6 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "google/protobuf/text_format.h"
 #include "tensorflow/core/platform/logging.h"
 #if TOCO_SUPPORT_PORTABLE_PROTOS
 #include "third_party/protobuf/src/google/protobuf/text_format.h"
@@ -84,25 +83,6 @@ void DumpGraphvizVideoFrame(const Model& model);
 void LogDump(int log_level, const string& message, const Model& model);
 void LogSummary(int log_level, const string& message, const Model& model);
 
-inline bool ParseFromStringOverload(const std::string& in,
-                                    TFLITE_PROTO_NS::Message* proto) {
-  return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
-}
-
-template <typename Proto>
-bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
-                                       Proto* proto) {
-  if (proto->ParseFromString(input_file_contents)) {
-    return true;
-  }
-
-  if (ParseFromStringOverload(input_file_contents, proto)) {
-    return true;
-  }
-
-  return false;
-}
-
 // TODO(b/36075966): Clean up when dims superseded by array shape.
 void ExtendShape(Shape* shape, int new_shape_size);
 
@@ -298,6 +278,8 @@ void CheckFinalDataTypesSatisfied(const Model& model);
 
 ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);
 
+void UseArraysExtraInfo(Model* model);
+
 }  // namespace toco
 
 #endif  // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_