Adds command-line option (toco_compatible, bool) that makes the optimize_for_inference...
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Apr 2018 22:45:16 +0000 (15:45 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 22:49:04 +0000 (15:49 -0700)
This change does not alter existing behavior (the boolean is set to false by default).

PiperOrigin-RevId: 191660378
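
For context, a minimal sketch of how the new flag might be passed when the tool is run as a script. The graph paths and node names below are hypothetical, and every flag other than --toco_compatible is assumed from the tool's existing command-line interface, so treat this as an illustration rather than the canonical invocation:

import subprocess
import sys

# Run the optimize_for_inference tool on a frozen graph, restricting the
# rewrites to ones the TF Lite converter (TOCO) can handle.
subprocess.check_call([
    sys.executable, "-m", "tensorflow.python.tools.optimize_for_inference",
    "--input=frozen_graph.pb",      # hypothetical frozen GraphDef path
    "--output=optimized_graph.pb",  # hypothetical output path
    "--input_names=input",          # hypothetical input node name
    "--output_names=softmax",       # hypothetical output node name
    "--frozen_graph=True",
    "--toco_compatible=True",       # the new flag added by this change
])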

tensorflow/python/tools/optimize_for_inference.py
tensorflow/python/tools/optimize_for_inference_lib.py

index 902748d..dac6a06 100644
--- a/tensorflow/python/tools/optimize_for_inference.py
+++ b/tensorflow/python/tools/optimize_for_inference.py
@@ -87,7 +87,9 @@ def main(unused_args):
   output_graph_def = optimize_for_inference_lib.optimize_for_inference(
       input_graph_def,
       FLAGS.input_names.split(","),
-      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)
+      FLAGS.output_names.split(","),
+      FLAGS.placeholder_type_enum,
+      FLAGS.toco_compatible)
 
   if FLAGS.frozen_graph:
     f = gfile.FastGFile(FLAGS.output, "w")
@@ -138,6 +140,14 @@ def parse_args():
       type=int,
       default=dtypes.float32.as_datatype_enum,
       help="The AttrValue enum to use for placeholders.")
+  parser.add_argument(
+      "--toco_compatible",
+      type=bool,
+      default=False,
+      help="""\
+      If true, only use ops compatible with Tensorflow
+      Lite Optimizing Converter.\
+      """)
   return parser.parse_known_args()
 
 
index 9c19271..bb90d1c 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -87,7 +87,7 @@ EPSILON_ATTR = {
 
 
 def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
-                           placeholder_type_enum):
+                           placeholder_type_enum, toco_compatible=False):
   """Applies a series of inference optimizations on the input graph.
 
   Args:
@@ -98,6 +98,8 @@ def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
       results.
     placeholder_type_enum: The AttrValue enum for the placeholder data type, or
         a list that specifies one value per input node name.
+    toco_compatible: Boolean, if True, only runs optimizations that result in
+      TOCO compatible graph operations (default=False).
 
   Returns:
     An optimized version of the input graph.
@@ -110,8 +112,9 @@ def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
   optimized_graph_def = graph_util.remove_training_nodes(
       optimized_graph_def, output_node_names)
   optimized_graph_def = fold_batch_norms(optimized_graph_def)
-  optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
-                                             output_node_names)
+  if not toco_compatible:
+    optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
+                                               output_node_names)
   ensure_graph_is_valid(optimized_graph_def)
   return optimized_graph_def
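
A minimal usage sketch of the updated library call, assuming a frozen GraphDef on disk; the file paths and the "input"/"softmax" node names are placeholders:

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib

# Load a frozen GraphDef from disk (path is a placeholder).
input_graph_def = graph_pb2.GraphDef()
with gfile.GFile("frozen_graph.pb", "rb") as f:
  input_graph_def.ParseFromString(f.read())

# With toco_compatible=True the fuse_resize_and_conv rewrite is skipped, so
# the optimized graph only uses rewrites compatible with the TF Lite
# Optimizing Converter (TOCO).
optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    ["input"],    # input node names (placeholder)
    ["softmax"],  # output node names (placeholder)
    dtypes.float32.as_datatype_enum,
    toco_compatible=True)

with gfile.GFile("optimized_graph.pb", "wb") as f:
  f.write(optimized_graph_def.SerializeToString())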