[Relay][Pass]Improve memory_allocation pass to support multiple i/o dynamic kernels...
authorYao Wang <kevinthesunwy@gmail.com>
Sat, 4 Jan 2020 06:19:00 +0000 (22:19 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Sat, 4 Jan 2020 06:19:00 +0000 (22:19 -0800)
* Add more shape funcs

* Fix test

* Enhance test_any_concat

* Fix pylint

* Minor fix test

* Fix pylint

* Minor refactor

* Add test any for elemwise

python/tvm/relay/memory_alloc.py
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/_transform.py
tests/python/relay/test_any.py

index 9de5431..a8d1a30 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-return,invalid-name,len-as-condition
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
 """
 A pass for manifesting explicit memory allocations.
 """
@@ -173,6 +173,8 @@ class ManifestAllocPass(ExprMutator):
             new_args = [self.visit(arg) for arg in call.args]
             ins = expr.Tuple(new_args)
             ret_type = call.checked_type
+            view = LinearizeRetType(ret_type)
+            out_types = view.unpack()
 
             is_dynamic = ret_type.is_dynamic()
             # TODO(@jroesch): restore this code, more complex then it seems
@@ -180,26 +182,37 @@ class ManifestAllocPass(ExprMutator):
             #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
 
             if is_dynamic:
-                assert isinstance(ret_type, ty.TensorType)
                 shape_func_ins = []
                 engine = compile_engine.get()
                 cfunc = engine.lower_shape_func(call.op, self.target_host)
                 input_states = cfunc.shape_func_param_states
 
                 is_inputs = []
+                input_pos = 0
                 for i, (arg, state) in enumerate(zip(new_args, input_states)):
                     state = int(state)
                     # Pass Shapes
                     if state == 2:
-                        sh_of = self.visit(self.shape_of(arg))
-                        shape_func_ins.append(
-                            scope.let("in_shape_{0}".format(i), sh_of))
+                        if isinstance(arg.type_annotation, ty.TupleType):
+                            for j in range(len(arg.type_annotation.fields)):
+                                let_in_arg = scope.let("in_arg_{0}".format(input_pos + j),
+                                                       expr.TupleGetItem(arg, j))
+                                sh_of = self.visit(self.shape_of(let_in_arg))
+                                shape_func_ins.append(
+                                    scope.let("in_shape_{0}".format(input_pos + j), sh_of))
+                            input_pos += len(arg.type_annotation.fields)
+                        else:
+                            sh_of = self.visit(self.shape_of(arg))
+                            shape_func_ins.append(
+                                scope.let("in_shape_{0}".format(input_pos), sh_of))
+                            input_pos += 1
                         is_inputs.append(0)
                     # Pass Inputs
                     elif state == 1:
                         new_arg = self.visit(arg)
                         shape_func_ins.append(
-                            scope.let("in_shape_{0}".format(i), new_arg))
+                            scope.let("in_shape_{0}".format(input_pos), new_arg))
+                        input_pos += 1
                         is_inputs.append(1)
                     # TODO(@jroesch): handle 3rd case
                     else:
@@ -219,9 +232,6 @@ class ManifestAllocPass(ExprMutator):
 
                 scope.let("shape_func", shape_call)
 
-                out_types = []
-                out_types.append(call.checked_type)
-
                 storages = []
                 for out_shape, out_type in zip(out_shapes, out_types):
                     size = self.compute_storage_in_relay(
@@ -242,15 +252,13 @@ class ManifestAllocPass(ExprMutator):
                     alloc = scope.let("out_{i}".format(i=i), alloc)
                     outs.append(alloc)
 
-                invoke = self.invoke_tvm(call.op, ins, expr.Tuple(outs))
+                tuple_outs = expr.Tuple(outs)
+                invoke = self.invoke_tvm(call.op, ins, tuple_outs)
                 scope.let("", invoke)
-                return outs[0]
+                return outs[0] if len(outs) == 1 else tuple_outs
             else:
-                view = LinearizeRetType(ret_type)
-                out_tys = view.unpack()
-
                 outs = []
-                for i, out_ty in enumerate(out_tys):
+                for i, out_ty in enumerate(out_types):
                     out = self.make_static_allocation(scope, out_ty, i)
                     outs.append(out)
 
index 114ff2a..b4a3697 100644 (file)
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 import topi
+from topi.util import get_const_tuple
 from .op import register_compute, register_schedule, register_pattern, register_shape_func
 from .op import schedule_injective, OpPattern
 from ...hybrid import script
+from ...api import convert
 
 schedule_broadcast = schedule_injective
 schedule_elemwise = schedule_injective
@@ -120,20 +122,20 @@ def _cast_shape_function(x):
 def cast_shape_func(attrs, inputs, out_ndims):
     return [_cast_shape_function(*inputs)]
 
-# shape func
 @script
-def _full_shape_func(x):
-    out_ndim = len(x)
+def _full_shape_func(shape):
+    out_ndim = len(shape)
     out = output_tensor((out_ndim,), "int64")
     for i in const_range(out_ndim):
-        out[i] = x[i]
+        out[i] = int64(shape[i])
     return out
 
 def full_shape_func(attrs, inputs, out_ndims):
     """
     Shape func for zeros, zeros_like, ones, ones_like.
     """
-    return [_full_shape_func(*inputs)]
+    shape = get_const_tuple(attrs.shape)
+    return [_full_shape_func(convert(shape))]
 
 @script
 def _broadcast_shape_func(x, y, ndim):
@@ -177,9 +179,11 @@ def elemwise_shape_func(attrs, inputs, _):
 
 register_shape_func("cast", False, cast_shape_func)
 register_shape_func("zeros", False, full_shape_func)
-register_shape_func("zeros_like", False, full_shape_func)
+register_shape_func("zeros_like", False, elemwise_shape_func)
 register_shape_func("ones", False, full_shape_func)
-register_shape_func("ones_like", False, full_shape_func)
+register_shape_func("ones_like", False, elemwise_shape_func)
+register_shape_func("full", False, full_shape_func)
+register_shape_func("full_like", False, elemwise_shape_func)
 
 register_shape_func("add", False, broadcast_shape_func)
 register_shape_func("subtract", False, broadcast_shape_func)
@@ -196,6 +200,9 @@ register_shape_func("less", False, broadcast_shape_func)
 register_shape_func("less_equal", False, broadcast_shape_func)
 register_shape_func("greater", False, broadcast_shape_func)
 register_shape_func("greater_equal", False, broadcast_shape_func)
+register_shape_func("maximum", False, broadcast_shape_func)
+register_shape_func("minimum", False, broadcast_shape_func)
 
 register_shape_func("sqrt", False, elemwise_shape_func)
 register_shape_func("negative", False, elemwise_shape_func)
+register_shape_func("exp", False, elemwise_shape_func)
index de708fb..9f32c25 100644 (file)
@@ -452,24 +452,8 @@ def transpose_shape_func(attrs, inputs, _):
 @script
 def _squeeze_shape_func(data_shape, keep_axes):
     out = output_tensor((len(keep_axes),), "int64")
-    if len(keep_axes) == 0:
-        out_size = 0
-        for i in const_range(data_shape.shape[0]):
-            if data_shape[i] != 1:
-                out_size += 1
-
-        if out_size == 0:
-            out_size = 1
-        out = output_tensor((out_size,), "int64")
-        out[0] = int64(1)
-        pos = 0
-        for i in const_range(data_shape.shape[0]):
-            if data_shape[i] != 1:
-                out[pos] = data_shape[i]
-                pos += 1
-    else:
-        for i in const_range(len(keep_axes)):
-            out[i] = data_shape[keep_axes[i]]
+    for i in const_range(len(keep_axes)):
+        out[i] = data_shape[keep_axes[i]]
 
     return out
 
@@ -485,7 +469,16 @@ def squeeze_shape_func(attrs, inputs, _):
             if i not in axis:
                 keep_axes.append(i)
 
-    return [_squeeze_shape_func(inputs[0], convert(keep_axes))]
+    # Due to current relay type system, it is possible even
+    # a static kernel function needs shape function. To handle
+    # this case, we allow axis to be None in squeeze shape func
+    # for now.
+    # TODO(kevinthesun): Enhance relay type system to avoid this.
+    if keep_axes:
+        out = _squeeze_shape_func(inputs[0], convert(keep_axes))
+    else:
+        out = tvm.compute((), lambda *indices: 0)
+    return [out]
 
 @script
 def _reshape_like_shape_func(target_shape):
@@ -527,9 +520,56 @@ def _tile_shape_func(data, reps, ndim, tndim, rndim):
 
 @_reg.register_shape_func("tile", False)
 def tile_shape_func(attrs, inputs, _):
+    """
+    Shape function for tile op.
+    """
     reps = get_const_tuple(attrs.reps)
     ndim = inputs[0].shape[0].value
     rndim = len(reps)
     tndim = ndim if ndim > rndim else rndim
     return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
                              convert(tndim), convert(rndim))]
+
+@script
+def _split_shape_func(data_shape, index, indices_or_sections, axis):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    if len(indices_or_sections) == 1:
+        for i in const_range(data_shape.shape[0]):
+            if i == axis:
+                out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
+            else:
+                out[i] = data_shape[i]
+    else:
+        start = int64(0)
+        if index > 0:
+            start = int64(indices_or_sections[index - 1])
+        end = data_shape[axis]
+        if index < len(indices_or_sections):
+            end = int64(indices_or_sections[index])
+        for i in const_range(data_shape.shape[0]):
+            if i == axis:
+                out[i] = end - start
+            else:
+                out[i] = data_shape[i]
+    return out
+
+@_reg.register_shape_func("split", False)
+def split_shape_func(attrs, inputs, _):
+    """
+    Shape function for split op.
+    """
+    if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)):
+        indices_or_sections = get_const_int(attrs.indices_or_sections)
+    else:
+        indices_or_sections = get_const_tuple(attrs.indices_or_sections)
+
+    axis = get_const_int(attrs.axis)
+
+    num_out = indices_or_sections if isinstance(indices_or_sections, int) \
+        else len(indices_or_sections) + 1
+    if isinstance(indices_or_sections, int):
+        indices_or_sections = [indices_or_sections]
+    return [_split_shape_func(inputs[0],
+                              convert(i),
+                              convert(indices_or_sections),
+                              convert(axis)) for i in range(num_out)]
index d7246da..a30326c 100644 (file)
@@ -59,6 +59,22 @@ def test_any_broadcast():
     verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
     verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
 
+def verify_any_elemwise(x_shape, x_np_shape, op, np_op):
+    dtype = 'float32'
+    x = relay.var('x', shape=x_shape, dtype=dtype)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], op(x))
+    x_np = np.random.uniform(size=x_np_shape).astype(dtype)
+    res_np = np_op(x_np)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(x_np)
+        tvm.testing.assert_allclose(result.asnumpy(), res_np)
+
+def test_any_elemwise():
+    verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt)
+    verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative)
+    verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp)
 
 def test_any_broadcast_fail():
     # Test broadcast with incompatible values at runtime
@@ -107,12 +123,14 @@ def test_any_full():
 def test_any_concat():
     x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
     y = relay.var('y', shape=(1, 2), dtype="float32")
-    z = relay.op.concatenate([x, y], axis=0)
+    xx = x - relay.expr.const(3.0)
+    yy = y * relay.expr.const(5.0)
+    z = relay.op.concatenate([xx, yy], axis=0)
     mod = relay.module.Module()
     mod["main"] = relay.Function([x, y], z)
     x_np = np.random.uniform(size=(3, 2)).astype('float32')
     y_np = np.random.uniform(size=(1, 2)).astype('float32')
-    ref = np.concatenate([x_np, y_np], axis=0)
+    ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
     for kind in ["debug", "vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
         result = ex.evaluate()(x_np, y_np)
@@ -417,6 +435,24 @@ def test_any_global_pool2d():
     verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
                       "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4))
 
+def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.split(data, indices_or_sections, axis)
+    mod["main"] = relay.Function([data], y.astuple())
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        for ret, ref_ret in zip(result, ref_out_shape):
+            assert ret.asnumpy().shape == ref_ret, \
+                "Shape mismatch: expect %s but got %s." % (str(ref_ret), str(ret.asnumpy().shape))
+
+def test_any_split():
+    verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
+    verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
+
 def test_any_batch_flatten():
     mod = relay.Module()
     dtype = "float32"
@@ -601,11 +637,13 @@ def test_recursive_concat_with_wrong_annotation():
 if __name__ == "__main__":
     test_any_full()
     test_any_broadcast()
+    test_any_elemwise()
     test_any_broadcast_fail()
     test_any_concat()
     test_any_reshape()
     test_any_take()
     test_any_tile()
+    test_any_split()
     test_any_shape_of()
     test_any_reduce()
     test_any_layout_transform()