From a764216776465a5385596ca83af6edf3da72c504 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 3 Apr 2018 15:03:44 -0700 Subject: [PATCH] Accept toco ModelFlags protos on the command line. PiperOrigin-RevId: 191505886 --- tensorflow/contrib/lite/toco/args.h | 1 + .../contrib/lite/toco/model_cmdline_flags.cc | 24 +++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 52c7892..39e49bc 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -211,6 +211,7 @@ struct ParsedModelFlags { Arg allow_nonexistent_arrays = Arg(false); Arg allow_nonascii_arrays = Arg(false); Arg arrays_extra_info_file; + Arg model_flags_file; }; // Flags that describe the operation you would like to do (what conversion diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 4264f21..245eb52 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -160,6 +160,11 @@ bool ParseModelFlagsFromCommandLineFlags( "Path to an optional file containing a serialized ArraysExtraInfo " "proto allowing to pass extra information about arrays not specified " "in the input model file, such as extra MinMax information."), + Flag("model_flags_file", parsed_flags.model_flags_file.bind(), + parsed_flags.model_flags_file.default_value(), + "Path to an optional file containing a serialized ModelFlags proto. " + "Options specified on the command line will override the values in " + "the proto."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -182,7 +187,24 @@ void ReadModelFlagsFromCommandLineFlags( const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) { toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet"); -// "batch" flag only exists internally + // Load proto containing the initial model flags. + // Additional flags specified on the command line will overwrite the values. + if (parsed_model_flags.model_flags_file.specified()) { + string model_flags_file_contents; + QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(), + &model_flags_file_contents, + port::file::Defaults()) + .ok()) + << "Specified --model_flags_file=" + << parsed_model_flags.model_flags_file.value() + << " was not found or could not be read"; + QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents, + model_flags)) + << "Specified --model_flags_file=" + << parsed_model_flags.model_flags_file.value() + << " could not be parsed"; + } + #ifdef PLATFORM_GOOGLE CHECK(!((base::SpecifiedOnCommandLine("batch") && parsed_model_flags.variable_batch.specified()))) -- 2.7.4