unlike the existing --default_ranges_{min,max} flags, which only allowed
setting a single global value for all arrays.
This takes the form of a new embedded message in ModelFlags, which is
its own message so that it can be serialized separately. The command-line
interface is --arrays_extra_info_file=some_proto.pbtxt, i.e. we don't
try to make a command-line-flags-only interface; instead, we mandate putting
the info in a file. The rationale is that users may want to specify custom
minmax for hundreds of arrays, so it would be cumbersome to have all of that
on a command line.
This should be considered an experimental feature, in the sense that
in properly quantized models, minmax information is already embedded
in the graph (e.g. in FakeQuant nodes). This is an extension of the
existing --default_ranges_{min,max} feature which had turned out to be
too restrictive for many users.
PiperOrigin-RevId:
183326000
<< "*\n"
<< "* If you would like to carry on with the slow code, compile\n"
<< "* with this preprocessor token defined:\n"
- << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+ << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
<< "*\n"
<< "* The right thing to do, if you care about performance, is to add\n"
<< "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
],
deps = [
# Placeholder for internal file dependency.
+ "@protobuf_archive//:protobuf_headers",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
Arg<bool> dump_graphviz_video = Arg<bool>(false);
Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
+ Arg<string> arrays_extra_info_file;
};
// Flags that describe the operation you would like to do (what conversion
"ranging from 32 to 127. This is disallowed by default so as to "
"catch common copy-and-paste issues where invisible unicode "
"characters are unwittingly added to these strings."),
+ Flag(
+ "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
+ parsed_flags.arrays_extra_info_file.default_value(),
+ "Path to an optional file containing a serialized ArraysExtraInfo "
+ "proto allowing to pass extra information about arrays not specified "
+ "in the input model file, such as extra MinMax information."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
parsed_model_flags.allow_nonascii_arrays.value());
model_flags->set_allow_nonexistent_arrays(
parsed_model_flags.allow_nonexistent_arrays.value());
+
+ // Load the optional ArraysExtraInfo proto (text or binary format) named by
+ // --arrays_extra_info_file into model_flags. The user explicitly asked for
+ // this extra information, so a file that cannot be read or parsed is a fatal
+ // error rather than something to silently ignore.
+ if (parsed_model_flags.arrays_extra_info_file.specified()) {
+   const auto& arrays_extra_info_file =
+       parsed_model_flags.arrays_extra_info_file.value();
+   string arrays_extra_info_file_contents;
+   QCHECK(port::file::GetContents(arrays_extra_info_file,
+                                  &arrays_extra_info_file_contents,
+                                  port::file::Defaults())
+              .ok())
+       << "Failed to read --arrays_extra_info_file: "
+       << arrays_extra_info_file;
+   // ParseFromStringEitherTextOrBinary returns false when the contents parse
+   // neither as a binary nor as a text-format proto.
+   QCHECK(ParseFromStringEitherTextOrBinary(
+       arrays_extra_info_file_contents,
+       model_flags->mutable_arrays_extra_info()))
+       << "Failed to parse --arrays_extra_info_file as an ArraysExtraInfo "
+       << "proto (tried both binary and text formats): "
+       << arrays_extra_info_file;
+ }
}
ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
optional int32 size = 3;
}
+// An ArraysExtraInfo message stores a collection of additional information
+// about arrays in a model, complementing the information in the model itself.
+// It is intentionally a separate message so that it may be serialized and
+// passed separately from the model. See --arrays_extra_info_file.
+//
+// A typical use case is to manually specify MinMax for specific arrays in a
+// model that does not already contain such MinMax information.
+message ArraysExtraInfo {
+ // Extra information about a single array, identified by name.
+ message Entry {
+ // Name of the array this entry applies to. UseArraysExtraInfo() aborts if
+ // the model has no array with this name.
+ optional string name = 1;
+ // Lower bound of the MinMax range copied onto the named array.
+ optional float min = 2;
+ // Upper bound of the MinMax range copied onto the named array.
+ optional float max = 3;
+ }
+ // One entry per array for which extra information is supplied.
+ repeated Entry entries = 1;
+}
+
// ModelFlags encodes properties of a model that, depending on the file
// format, may or may not be recorded in the model file. The purpose of
// representing these properties in ModelFlags is to allow passing them
// optional int32 input_dims = 11 [ default = 4];
// repeated int32 input_shape = 13;
//
-// Next ID to USE: 18.
+// Next ID to USE: 19.
message ModelFlags {
// Information about the input arrays, i.e. the arrays from which input
// activations will be read.
// catch common copy-and-paste issues where invisible unicode
// characters are unwittingly added to these strings.
optional bool allow_nonascii_arrays = 17;
+
+ // If set, this ArraysExtraInfo allows to pass extra information about arrays
+ // not specified in the input model file, such as extra MinMax information.
+ optional ArraysExtraInfo arrays_extra_info = 18;
}
// can build and use on google internal environments and on OSX.
#include <string>
+#include "google/protobuf/text_format.h"
#include "tensorflow/contrib/lite/toco/format_port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
#endif // PLATFORM_GOOGLE
void CopyToBuffer(const string& src, char* dest);
} // namespace port
+
+// Parses `in` as a text-format proto into `proto` via TextFormat.
+// Returns whatever TextFormat::ParseFromString reports for the parse.
+inline bool ParseFromStringOverload(const std::string& in,
+                                    TFLITE_PROTO_NS::Message* proto) {
+  using TextFormat = TFLITE_PROTO_NS::TextFormat;
+  const bool parsed_as_text = TextFormat::ParseFromString(in, proto);
+  return parsed_as_text;
+}
+
+// Fills `proto` from `input_file_contents`, accepting either serialization:
+// the binary wire format is attempted first, and only if that fails is the
+// text format attempted. Returns true if either attempt succeeds, false if
+// both fail.
+template <typename Proto>
+bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
+                                       Proto* proto) {
+  // Short-circuit evaluation keeps the original order: the text-format parse
+  // runs only when the binary parse has failed.
+  return proto->ParseFromString(input_file_contents) ||
+         ParseFromStringOverload(input_file_contents, proto);
+}
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
}
SetFinalDataTypeOnInputs(toco_flags, model);
+ UseArraysExtraInfo(model);
// Remove unused ops before performing any other optimizations. This is to
// stop optimizations from crossing the input/output boundaries. For example
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
+
if (quantize_output) {
RunGraphTransformations(model, "pre-quantization graph transformations",
{new HardcodeMinMax, new DropFakeQuant});
model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
model->flags.set_allow_nonexistent_arrays(
model_flags.allow_nonexistent_arrays());
+
+ CHECK(!model->flags.has_arrays_extra_info());
+ *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
}
void CheckIsReadyForQuantization(const Model& model) {
}
}
+// Copies the min/max values of every entry in
+// model->flags.arrays_extra_info() onto the MinMax of the array with the
+// matching name, creating that MinMax if the array does not have one yet.
+// Aborts (QCHECK) if an entry names an array that is not in the model.
+void UseArraysExtraInfo(Model* model) {
+  const auto& extra_info = model->flags.arrays_extra_info();
+  for (const auto& extra_entry : extra_info.entries()) {
+    const auto& array_name = extra_entry.name();
+    QCHECK(model->HasArray(array_name))
+        << "ArraysExtraInfo refers to non-existent array name: "
+        << array_name;
+    auto& array_minmax = model->GetArray(array_name).GetOrCreateMinMax();
+    array_minmax.min = extra_entry.min();
+    array_minmax.max = extra_entry.max();
+  }
+}
+
} // namespace toco
#include <string>
#include <vector>
-#include "google/protobuf/text_format.h"
#include "tensorflow/core/platform/logging.h"
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/src/google/protobuf/text_format.h"
void LogDump(int log_level, const string& message, const Model& model);
void LogSummary(int log_level, const string& message, const Model& model);
-inline bool ParseFromStringOverload(const std::string& in,
- TFLITE_PROTO_NS::Message* proto) {
- return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
-}
-
-template <typename Proto>
-bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
- Proto* proto) {
- if (proto->ParseFromString(input_file_contents)) {
- return true;
- }
-
- if (ParseFromStringOverload(input_file_contents, proto)) {
- return true;
- }
-
- return false;
-}
-
// TODO(b/36075966): Clean up when dims superseded by array shape.
void ExtendShape(Shape* shape, int new_shape_size);
ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);
+// Applies the entries of model->flags.arrays_extra_info() to the model: for
+// each entry, sets the MinMax of the array named by the entry to the entry's
+// min and max. Aborts if an entry refers to an array that does not exist in
+// the model.
+void UseArraysExtraInfo(Model* model);
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_