output_graph_def = optimize_for_inference_lib.optimize_for_inference(
input_graph_def,
FLAGS.input_names.split(","),
- FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)
+ FLAGS.output_names.split(","),
+ FLAGS.placeholder_type_enum,
+ FLAGS.toco_compatible)
if FLAGS.frozen_graph:
f = gfile.FastGFile(FLAGS.output, "w")
type=int,
default=dtypes.float32.as_datatype_enum,
help="The AttrValue enum to use for placeholders.")
+ parser.add_argument(
+ "--toco_compatible",
+ type=bool,
+ default=False,
+ help="""\
+ If true, only use ops compatible with Tensorflow
+ Lite Optimizing Converter.\
+ """)
return parser.parse_known_args()
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
- placeholder_type_enum):
+ placeholder_type_enum, toco_compatible=False):
"""Applies a series of inference optimizations on the input graph.
Args:
results.
placeholder_type_enum: The AttrValue enum for the placeholder data type, or
a list that specifies one value per input node name.
+ toco_compatible: Boolean, if True, only runs optimizations that result in
+ TOCO compatible graph operations (default=False).
Returns:
An optimized version of the input graph.
optimized_graph_def = graph_util.remove_training_nodes(
optimized_graph_def, output_node_names)
optimized_graph_def = fold_batch_norms(optimized_graph_def)
- optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
- output_node_names)
+ if not toco_compatible:
+ optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
+ output_node_names)
ensure_graph_is_valid(optimized_graph_def)
return optimized_graph_def