[Relay][Prelude] Add more dtypes to tensor_t (#4233)
authorWei Chen <ipondering.weic@gmail.com>
Fri, 1 Nov 2019 20:37:58 +0000 (13:37 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 1 Nov 2019 20:37:58 +0000 (13:37 -0700)
python/tvm/relay/prelude.py
tests/python/frontend/tensorflow/test_forward.py

index d27ffe5..1625e19 100644 (file)
@@ -591,6 +591,14 @@ class Prelude:
         for global_def in GLOBAL_DEFS:
             setattr(self, global_def, self.mod.get_global_var(global_def))
 
-        for dtype in ['float32', 'int32']:
+        for dtype in ['float32',
+                      'float16',
+                      'float64',
+                      'int32',
+                      'uint8',
+                      'int8',
+                      'int16',
+                      'uint16',
+                      'int64']:
             tensor_array_ops = TensorArrayOps(self, dtype)
             tensor_array_ops.register()
index b17ec12..4554293 100644 (file)
@@ -48,6 +48,17 @@ def convert_to_list(x):
         x = [x]
     return x
 
+tf_dtypes = {
+    'float32': tf.float32,
+    'float16': tf.float16,
+    'float64': tf.float64,
+    'int32': tf.int32,
+    'uint8' : tf.uint8,
+    'int8': tf.int8,
+    'int16': tf.int16,
+    'uint16': tf.uint16,
+    'int64': tf.int64,
+}
 
 def vmobj_to_list(o):
     if isinstance(o, tvm.relay.backend.vmobj.Tensor):
@@ -626,34 +637,24 @@ def test_forward_squeeze():
 def test_tensor_array_constructor():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype = {
-                'float32': tf.float32,
-                'int32': tf.int32
-            }[dtype_str]
-            t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
-                dtype_str), dtype=dtype)
-            t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
-                dtype_str), dtype=dtype)
-            ta1 = tf.TensorArray(dtype=dtype, size=2,
-                                 infer_shape=False, dynamic_size=False)
+            dtype = tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
+            t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
+            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
             ta2 = ta1.write(0, t)
             ta3 = ta2.write(1, t2)
             out = ta3.read(0)
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
-    run('float32')
-    run('int32')
+    for dtype in tf_dtypes.keys():
+        run(dtype)
 
 
 def test_tensor_array_scatter():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype = {
-                'float32': tf.float32,
-                'int32': tf.int32
-            }[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(
-                dtype_str), dtype=dtype)
+            dtype =  tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
             indices = tf.constant([2, 1, 0])
             ta1 = tf.TensorArray(dtype=dtype, size=3,
                                  infer_shape=False, dynamic_size=False)
@@ -663,12 +664,10 @@ def test_tensor_array_scatter():
             out2 = ta2.read(2)
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
-            compare_tf_with_tvm(
-                [], [], ['TensorArrayReadV3_1:0'], mode='debug')
-            compare_tf_with_tvm(
-                [], [], ['TensorArrayReadV3_2:0'], mode='debug')
-    run('float32')
-    run('int32')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
+    for dtype in tf_dtypes.keys():
+        run(dtype)
 
 # TODO(wweic): Fix gather issue with PartialEvaluate
 # def test_tensor_array_gather():
@@ -687,12 +686,8 @@ def test_tensor_array_scatter():
 def test_tensor_array_split():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype = {
-                'float32': tf.float32,
-                'int32': tf.int32
-            }[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
-                            6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
+            dtype =  tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
             ta1 = tf.TensorArray(dtype=dtype, size=4,
                                  infer_shape=False, dynamic_size=False)
@@ -703,50 +698,38 @@ def test_tensor_array_split():
             out3 = ta2.read(3)
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
-            compare_tf_with_tvm(
-                [], [], ['TensorArrayReadV3_1:0'], mode='debug')
-            compare_tf_with_tvm(
-                [], [], ['TensorArrayReadV3_2:0'], mode='debug')
-            compare_tf_with_tvm(
-                [], [], ['TensorArrayReadV3_3:0'], mode='debug')
-    run('float32')
-    run('int32')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
+    for dtype in tf_dtypes.keys():
+        run(dtype)
 
 
 def test_tensor_array_concat():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype = {
-                'float32': tf.float32,
-                'int32': tf.int32
-            }[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
-                            6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
+            dtype = tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
             ta1 = tf.TensorArray(dtype=dtype, size=4,
                                  infer_shape=False, dynamic_size=False)
             ta2 = ta1.split(t, split_length)
             t = ta2.concat()
-            compare_tf_with_tvm(
-                [], [], ['TensorArrayConcatV3:0'], mode='debug')
-    run('float32')
-    run('int32')
+            compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
+    for dtype in tf_dtypes.keys():
+        run(dtype)
 
 
 def test_tensor_array_size():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype = {
-                'float32': tf.float32,
-                'int32': tf.int32
-            }[dtype_str]
-            ta1 = tf.TensorArray(dtype=dtype, size=2,
-                                 infer_shape=False, dynamic_size=False)
+            dtype =  tf_dtypes[dtype_str]
+            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
             out = ta1.size()
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
-    run('float32')
-    run('int32')
+    for dtype in tf_dtypes.keys():
+        run(dtype)
 
 #######################################################################
 # ConcatV2