[TF][Relay][Op] Pass module when infer shape (#4287)
author Wei Chen <ipondering.weic@gmail.com>
Mon, 11 Nov 2019 19:22:14 +0000 (11:22 -0800)
committer Haichen Shen <shenhaichen@gmail.com>
Mon, 11 Nov 2019 19:22:14 +0000 (11:22 -0800)
* [TF][Relay][Op] Pass module when infer shape

* Fix lint

* Improve style

* Add test

python/tvm/relay/frontend/common.py
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 25ba0ef..473b77a 100644 (file)
@@ -451,20 +451,24 @@ def get_name(node):
     return name
 
 
-def infer_type(node):
+def infer_type(node, mod=None):
     """A method to infer the type of an intermediate node in the relay graph."""
-    mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
-    mod = _transform.InferType()(mod)
-    entry = mod["main"]
+    new_mod = _module.Module.from_expr(node)
+    if mod is not None:
+        new_mod.update(mod)
+    new_mod = _transform.InferType()(new_mod)
+    entry = new_mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body
 
-
-def infer_shape(inputs):
-    """A method to get the output shape of an intermediate node in the graph."""
-    out_type = infer_type(inputs)
-    out_shapes = get_const_tuple(out_type.checked_type.shape)
-    return out_shapes
-
+def infer_shape(inputs, mod=None):
+    """A method to get the output type of an intermediate node in the graph."""
+    out_type = infer_type(inputs, mod=mod)
+    checked_type = out_type.checked_type
+    if hasattr(checked_type, 'shape'):
+        # Regular operator that outputs tensors
+        return get_const_tuple(out_type.checked_type.shape)
+    # The return type is not a tensor, for example List
+    return checked_type
 
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
index 837b8d3..6a24e74 100644 (file)
@@ -90,6 +90,12 @@ def _get_list_param(params, input_node):
 def _get_tuple_param(params, input_node):
     return tuple(_get_param(params, input_node))
 
+def _need_module_for_shape_inference(op):
+    return op in ['StridedSlice']
+
+def _need_prelude_for_shape_inference(op):
+    return "TensorArray" in op
+
 def _rsqrt():
     def _impl(inputs, attr, params):
         inputs.append(tvm.relay.const(-0.5, attr['T'].name))
@@ -893,7 +899,7 @@ def _gather_nd():
     return _impl
 
 def _stridedSlice():
-    def _impl(inputs, attr, params):
+    def _impl(inputs, attr, params, mod):
         """Strided Slice.
         Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
         Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
@@ -976,7 +982,7 @@ def _stridedSlice():
         if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
             begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
         out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
-        out_shape = _infer_shape(out)
+        out_shape = _infer_shape(out, mod=mod)
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
 
@@ -2169,7 +2175,8 @@ class GraphProto(object):
 
                 # Infer shapes even without specifying "add_shapes=True"
                 if output_shapes == [None]:
-                    out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
+                    out_shapes = [_infer_shape(node_item, self._mod)
+                                  for node_item in self._nodes[node.name]]
                     self._output_shapes[node.name] = out_shapes
 
                 if self._output_shapes[node.name] and shape and node.name in shape:
@@ -2179,7 +2186,7 @@ class GraphProto(object):
             node_output = self._nodes[node.name]
             if shape and (not self._output_shapes[node.name][0]
                           or -1 in self._output_shapes[node.name][0]):
-                out_shapes = [_infer_shape(node_item) for node_item in node_output]
+                out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output]
                 self._output_shapes[node.name] = out_shapes
 
         out = []
@@ -2470,8 +2477,10 @@ class GraphProto(object):
         if op_name in identity_list:
             sym = get_relay_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
-            if 'TensorArray' in op_name:
+            if _need_prelude_for_shape_inference(op_name):
                 sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+            elif _need_module_for_shape_inference(op_name):
+                sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
             else:
                 sym = convert_map[op_name](inputs, attrs, self._params)
 
index b8c980b..4790af3 100644 (file)
@@ -746,7 +746,8 @@ def test_tensor_array_concat():
                                  infer_shape=False, dynamic_size=False)
             ta2 = ta1.split(t, split_length)
             t = ta2.concat()
-            compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
+            out = tf.identity(t)
+            compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
     for dtype in tf_dtypes.keys():
         run(dtype)