Fix tf parser (#5794)
authorYao Wang <kevinthesunwy@gmail.com>
Sat, 13 Jun 2020 03:32:46 +0000 (20:32 -0700)
committerGitHub <noreply@github.com>
Sat, 13 Jun 2020 03:32:46 +0000 (20:32 -0700)
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/frontend/tensorflow_parser.py

index 5778b25..af09877 100644 (file)
@@ -1322,14 +1322,10 @@ def _shape():
 
 def _fill():
     def _impl(inputs, attr, params, mod):
-        output_shape = attr['_output_shapes'][0]
-        # Output shape must be defined to avoid errors. If any axis is not, we must
-        # try to compute its shape.
-        if output_shape is None or -1 in output_shape:
-            try:
-                output_shape = _expr.Constant(_infer_value(inputs[0], params, mod))
-            except Exception:
-                output_shape = inputs[0]
+        try:
+            output_shape = _infer_value(inputs[0], params, mod).asnumpy().tolist()
+        except Exception:
+            output_shape = inputs[0]
 
         return _op.full(inputs[1], output_shape, attr['T'].name)
     return _impl
index fdbb876..771aed0 100644 (file)
@@ -30,6 +30,10 @@ class TFParser(object):
     model_dir : tensorflow frozen pb file or a directory that contains saved
     model or checkpoints.
 
+    outputs : List of output tensor names (Optional)
+        Optional output node names. This will be protected for saved model
+        when we do remove training nodes.
+
     Examples
     --------
     .. code-block:: python
@@ -38,11 +42,12 @@ class TFParser(object):
         graphdef = parser.parse()
     """
 
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, outputs=None):
         from tensorflow.core.framework import graph_pb2
         self._tmp_dir = util.tempdir()
         self._model_dir = model_dir
         self._graph = graph_pb2.GraphDef()
+        self._outputs = outputs or []
 
     def _set_graph(self, graph):
         """Set Graph"""
@@ -128,7 +133,8 @@ class TFParser(object):
             output_graph_def = graph_pb2.GraphDef()
             with open(output_graph_filename, "rb") as f:
                 output_graph_def.ParseFromString(f.read())
-            output_graph_def = graph_util.remove_training_nodes(output_graph_def)
+            output_graph_def = graph_util.remove_training_nodes(output_graph_def,
+                                                                protected_nodes=self._outputs)
             return output_graph_def
 
     def _load_ckpt(self):