Fix the TF tutorial to run against TF2.0 and TF1.x (#4104)
authorEric Platon <zaraki@gmx.com>
Tue, 12 Nov 2019 23:52:24 +0000 (00:52 +0100)
committerYizhi Liu <liuyizhi@apache.org>
Tue, 12 Nov 2019 23:52:24 +0000 (15:52 -0800)
* WIP Run the TF tutorial on TF2

* Remove debugger statement.

* Complete the support for TF2.0's `resize`.

TF2.0 adds a `half_pixel_centers` attribute to the `resize` function in
the image API. This commit completes the hooks in Relay's TF frontend.

At the point of this commit, no new test yet. Also, this commit
addresses solely the `resize` change. Other commits address other
changes in TF2.0.

* Support TF2.0 in the tutorial by using the compat API.

This looks cleaner than trying to detect the TF version.

* Use the TF compat API, so as to support TF2.0.

This is a direct change, relying on the compat API provided by the TF
team.

This code will last as long as the compat API exists, so a
"proper" support for TF1.x and 2.x will require more work in some
future.

* Partial support for EXPLICIT padding introduced in TF2.0.

Explicit padding is a special case in TF2.0 (see reference linked
below). Some models are serialized with that mode, and break TF support
in TVM.

Support is *partial* as EXPLICIT falls back to set padding on the
Relay op, which only supports 2 values. At some point, padding may need
to be extended to support 4 values, but that is out of scope of this
support commit.

Reference on EXPLICIT padding: https://github.com/tensorflow/tensorflow/commit/ec81825aaf7e848d9f8ddffdf1e0d20aebe9172c#diff-1d1c0bb0a880f85b6164f71dbb2f446e

* Guard on checking for optional TF2.0 attribute.

* Do not expect Relay to implement TF-specific attributes.

The `half_pixel_centers` attribute is a new feature in TF2.0. Earlier
commits of mine mistakenly introduce them in the Relay API. This is
probably not what Relay is expected to support, and the semantics of
`half_pixel_centers` is unclear (to me, at least) at this point.

* Remove unclear comment.

CR https://github.com/dmlc/tvm/pull/4104#discussion_r338705742

Addresses #4104

* Changes after review.

Complying without understanding the rationale for now.

* Fix the arguments set mistakenly.

An argument ignored for the wrong operation.

python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/testing/tf.py
tutorials/frontend/from_tensorflow.py

index 2a65678..35c857a 100644 (file)
@@ -307,8 +307,10 @@ def _conv(opname):
         use_bias = len(inputs) == 3
         channel_axis = 1 if attr['data_format'] == "NCHW" else 3
 
+        # Ignore the new attributes from TF2.0, for now.
         out = AttrCvt(
             op_name=_dimension_picker('conv'),
+            ignores=['explicit_paddings'],
             transforms={
                 'kernel_shape': 'kernel_size',
                 'data_format': 'data_layout',
@@ -405,8 +407,9 @@ def _resize_bilinear():
         # NHWC
         attr['layout'] = 'NHWC'
 
+        # Ignore the new attributes from TF2.0, for now.
         return AttrCvt(op_name="resize",
-                       ignores=['Tdim'],
+                       ignores=['Tdim', 'half_pixel_centers'],
                        extras={'method': "bilinear"})(inputs, attr)
     return _impl
 
index a56e6fe..79d0d82 100644 (file)
@@ -80,7 +80,7 @@ def AddShapesToGraphDef(session, out_node):
 
     """
 
-    graph_def = tf.graph_util.convert_variables_to_constants(
+    graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
         session,
         session.graph.as_graph_def(add_shapes=True),
         [out_node],
@@ -112,13 +112,13 @@ class NodeLookup(object):
             dict from integer node ID to human-readable string.
 
         """
-        if not tf.gfile.Exists(uid_lookup_path):
+        if not tf.compat.v1.io.gfile.exists(uid_lookup_path):
             tf.logging.fatal('File does not exist %s', uid_lookup_path)
-        if not tf.gfile.Exists(label_lookup_path):
+        if not tf.compat.v1.io.gfile.exists(label_lookup_path):
             tf.logging.fatal('File does not exist %s', label_lookup_path)
 
         # Loads mapping from string UID to human-readable string
-        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
+        proto_as_ascii_lines = tf.compat.v1.gfile.GFile(uid_lookup_path).readlines()
         uid_to_human = {}
         p = re.compile(r'[n\d]*[ \S,]*')
         for line in proto_as_ascii_lines:
@@ -129,7 +129,7 @@ class NodeLookup(object):
 
         # Loads mapping from string UID to integer node ID.
         node_id_to_uid = {}
-        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
+        proto_as_ascii = tf.compat.v1.gfile.GFile(label_lookup_path).readlines()
         for line in proto_as_ascii:
             if line.startswith('  target_class:'):
                 target_class = int(line.split(': ')[1])
@@ -209,7 +209,7 @@ def get_workload(model_path, model_sub_path=None):
         path_model = download_testdata(model_url, model_path, module='tf')
 
     # Creates graph from saved graph_def.pb.
-    with tf.gfile.FastGFile(path_model, 'rb') as f:
+    with tf.compat.v1.gfile.FastGFile(path_model, 'rb') as f:
         graph_def = tf.GraphDef()
         graph_def.ParseFromString(f.read())
         graph = tf.import_graph_def(graph_def, name='')
@@ -299,7 +299,7 @@ def _create_ptb_vocabulary(data_dir):
     file_name = 'ptb.train.txt'
     def _read_words(filename):
         """Read the data for creating vocabulary"""
-        with tf.gfile.GFile(filename, "r") as f:
+        with tf.compat.v1.gfile.GFile(filename, "r") as f:
             return f.read().encode("utf-8").decode("utf-8").replace("\n", "<eos>").split()
 
     def _build_vocab(filename):
index 34865f0..2c109cb 100644 (file)
@@ -89,14 +89,14 @@ label_path = download_testdata(label_map_url, label_map, module='data')
 # ------------
 # Creates tensorflow graph definition from protobuf file.
 
-with tf.gfile.FastGFile(model_path, 'rb') as f:
-    graph_def = tf.GraphDef()
+with tf.compat.v1.gfile.GFile(model_path, 'rb') as f:
+    graph_def = tf.compat.v1.GraphDef()
     graph_def.ParseFromString(f.read())
     graph = tf.import_graph_def(graph_def, name='')
     # Call the utility to import the graph definition into default graph.
     graph_def = tf_testing.ProcessGraphDefParam(graph_def)
     # Add shapes to the graph.
-    with tf.Session() as sess:
+    with tf.compat.v1.Session() as sess:
         graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
 
 ######################################################################
@@ -187,8 +187,8 @@ for node_id in top_k:
 def create_graph():
     """Creates a graph from saved GraphDef file and returns a saver."""
     # Creates graph from saved graph_def.pb.
-    with tf.gfile.FastGFile(model_path, 'rb') as f:
-        graph_def = tf.GraphDef()
+    with tf.compat.v1.gfile.GFile(model_path, 'rb') as f:
+        graph_def = tf.compat.v1.GraphDef()
         graph_def.ParseFromString(f.read())
         graph = tf.import_graph_def(graph_def, name='')
         # Call the utility to import the graph definition into default graph.
@@ -206,14 +206,14 @@ def run_inference_on_image(image):
     -------
         Nothing
     """
-    if not tf.gfile.Exists(image):
+    if not tf.compat.v1.io.gfile.exists(image):
         tf.logging.fatal('File does not exist %s', image)
-    image_data = tf.gfile.FastGFile(image, 'rb').read()
+    image_data = tf.compat.v1.gfile.GFile(image, 'rb').read()
 
     # Creates graph from saved GraphDef.
     create_graph()
 
-    with tf.Session() as sess:
+    with tf.compat.v1.Session() as sess:
         softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
         predictions = sess.run(softmax_tensor,
                                {'DecodeJpeg/contents:0': image_data})