// Deprecated flags
Arg<string> input_type;
Arg<string> input_types;
+ Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
Arg<bool> drop_control_dependency = Arg<bool>(false);
};
return changed;
}
+bool HardcodeMinMaxForSplit(Model* model, Operator* op) {
+ for (const auto& output : op->outputs) {
+ if (model->GetArray(output).minmax) {
+ LOG(WARNING) << "Skipping min-max setting for " << LogName(*op)
+ << " because output " << output << " already has min-max.";
+ return false;
+ }
+ }
+ // Data is in second input.
+ auto& input_array = model->GetArray(op->inputs[1]);
+ if (!input_array.minmax) {
+ return false;
+ } else {
+ for (const auto& output : op->outputs) {
+ auto& array = model->GetArray(output);
+ array.GetOrCreateMinMax() = *input_array.minmax;
+ }
+ return true;
+ }
+}
+
// The output of average or max pooling is within the same range as its input.
bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
auto& output_array = model->GetArray(op->outputs[0]);
changed = HardcodeMinMaxForConcatenation(model, op);
break;
+ case OperatorType::kTensorFlowSplit:
+ changed = HardcodeMinMaxForSplit(model, op);
+ break;
+
case OperatorType::kAveragePool:
case OperatorType::kMaxPool:
changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
"If true, ignore control dependency requirements in input TensorFlow "
"GraphDef. Otherwise an error will be raised upon control dependency "
"inputs."),
+ Flag("debug_disable_recurrent_cell_fusion",
+ parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
+ parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
+ "If true, disable fusion of known identifiable cell subgraphs into "
+ "cells. This includes, for example, specific forms of LSTM cell."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
// are not normally encoded in model files and in general may not be thought
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
-// Next Id: 13
+//
+// Next ID to use: 14.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
// - Default to false if the output format is TENSORFLOW_GRAPHDEF.
// - Default to true in all other cases.
optional bool drop_control_dependency = 12;
+
+ // Disables transformations that fuse subgraphs such as known LSTMs (not all
+ // LSTMs are identified).
+ optional bool debug_disable_recurrent_cell_fusion = 13;
}
}
transformations.Add(new ConvertPureConvToDepthwise);
if (SupportsLstmCell(output_format)) {
- transformations.Add(new IdentifyLstmCell);
+ if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
+ transformations.Add(new IdentifyLstmCell);
+ }
if (output_format == TFLITE) {
transformations.Add(new toco::SplitLstmCellInputs);
} else {