Register Shape Func for Some Operators to Handle Dynamic Shapes (#5955)
author    Siyuan Li <siyuanli.s.c@gmail.com>
Thu, 23 Jul 2020 20:23:49 +0000 (04:23 +0800)
committer GitHub <noreply@github.com>
Thu, 23 Jul 2020 20:23:49 +0000 (13:23 -0700)
* Register Shape Func for Floor Operator

Register the shape function for the `floor` operator; otherwise, compilation fails when the input of `floor` has a dynamic (`Any`) shape.
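
A minimal repro sketch of that case, mirroring the executor pattern used in `test_any.py` below (names and shapes here are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    # A tensor whose only dimension is unknown until runtime.
    mod = tvm.IRModule()
    data = relay.var("data", shape=(relay.Any(),), dtype="float32")
    mod["main"] = relay.Function([data], relay.floor(data))

    # Without a shape func registered for floor, compiling this under the
    # VM executor fails; with this patch it evaluates normally.
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    result = ex.evaluate()(np.random.uniform(size=(8,)).astype("float32"))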

* Register shape func for log

* add shape function for crop_and_resize

* change import location

* add mirror_pad shape function

* add test cases for crop_and_resize and mirror_pad shape funcs

* support different layouts

* fix pylint error

* fix pylint error

* add test for nchw layout

* block nchw test

* test for nchw

* use tvm.testing.assert_allclose instead

Co-authored-by: lisiyuan <lisiyuan@nucflow>
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/image/_image.py
python/tvm/relay/op/nn/_nn.py
tests/python/relay/test_any.py

index feeec1f..2ca2a01 100644 (file)
@@ -239,3 +239,5 @@ register_shape_func("tan", False, elemwise_shape_func)
 register_shape_func("fast_exp", False, elemwise_shape_func)
 register_shape_func("fast_tanh", False, elemwise_shape_func)
 register_shape_func("fast_erf", False, elemwise_shape_func)
+register_shape_func("floor", False, elemwise_shape_func)
+register_shape_func("log", False, elemwise_shape_func)
index bcb110f..795844f 100644 (file)
@@ -18,6 +18,9 @@
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
+from tvm.te.hybrid import script
+from tvm.runtime import convert
+
 import topi
 from topi.util import get_const_tuple
 from .. import op as reg
@@ -64,6 +67,36 @@ def compute_crop_and_resize(attrs, inputs, out_type):
 
 reg.register_injective_schedule("image.crop_and_resize")
 
+@script
+def _crop_and_resize_func(image_shape, boxes_shape, crop_size,
+                          height_axis, width_axis, channel_axis):
+    # The batch dim comes from the number of boxes, H and W from crop_size,
+    # and the channel dim passes through from the input image.
+    out = output_tensor((4,), "int64")
+    out[0] = boxes_shape[0]
+    out[height_axis] = int64(crop_size[0])
+    out[width_axis] = int64(crop_size[1])
+    out[channel_axis] = image_shape[channel_axis]
+    return out
+
+@reg.register_shape_func("image.crop_and_resize", False)
+def crop_and_resize_func(attrs, inputs, _):
+    """
+    Shape function for crop_and_resize op.
+    """
+    layout = attrs.layout
+    height_axis = width_axis = channel_axis = 1
+    for i, letter in enumerate(layout):
+        if letter == "H":
+            height_axis = i
+        if letter == "W":
+            width_axis = i
+        if letter == "C":
+            channel_axis = i
+    crop_size = get_const_tuple(attrs.crop_size)
+    return [_crop_and_resize_func(inputs[0], inputs[1], convert(crop_size),
+                                  convert(height_axis), convert(width_axis), convert(channel_axis))]
+
 
 # dilation2d
 reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy)
index 0757e96..cea592a 100644 (file)
@@ -441,6 +441,22 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
 reg.register_broadcast_schedule("nn.mirror_pad")
 
 
+@script
+def _mirror_pad_func(data_shape, pad_width):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(data_shape.shape[0]):
+        out[i] = data_shape[i] + int64(pad_width[i][0]) + int64(pad_width[i][1])
+    return out
+
+@reg.register_shape_func("nn.mirror_pad", False)
+def mirror_pad_func(attrs, inputs, _):
+    pad_width_tuple = [get_const_tuple(p) for p in attrs.pad_width]
+    return [_mirror_pad_func(inputs[0], convert(pad_width_tuple))]
+
+
 # conv2d_winograd related operators
 reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
                       strategy.conv2d_winograd_without_weight_transfrom_strategy)
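
And the mirror_pad shape rule as a plain-Python sketch (helper name made up; the check reuses the new test case):

    def expected_mirror_pad_shape(data_shape, pad_width):
        # Each output dim is the input dim plus its pre- and post-padding.
        return tuple(dim + before + after
                     for dim, (before, after) in zip(data_shape, pad_width))

    assert expected_mirror_pad_shape((1, 256, 232, 232),
                                     ((0, 0), (0, 0), (1, 1), (1, 1))) == (1, 256, 234, 234)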
index 6d940a5..6810d0b 100644 (file)
@@ -816,6 +816,63 @@ def test_mixed_input_type():
         assert result.asnumpy().shape == ref_out_shape, \
             "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
 
+def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
+                               layout, static_boxes, static_box_indices_shape, ref_out_shape):
+    mod = tvm.IRModule()
+    dtype = "float32"
+    indices_dtype = "int32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    boxes = relay.var('boxes', shape=boxes_shape, dtype=dtype)
+    box_indices = relay.var('box_indices', shape=box_indices_shape, dtype=indices_dtype)
+    y = relay.image.crop_and_resize(data, boxes, box_indices, crop_size, layout)
+    mod["main"] = relay.Function([data, boxes, box_indices], y)
+    data_np = np.random.uniform(size=data_shape).astype(dtype)
+    boxes_np = np.random.uniform(size=static_boxes).astype(dtype)
+    box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np, boxes_np, box_indices_np)
+        tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape)
+
+def test_any_crop_and_resize():
+    verify_any_crop_and_resize(
+        data_shape=(1, 234, 234, 256),
+        boxes_shape=(relay.Any(), 4),
+        box_indices_shape=(relay.Any(),),
+        crop_size=(14, 14),
+        layout='NHWC',
+        static_boxes=(128, 4),
+        static_box_indices_shape=(128,),
+        ref_out_shape=(128, 14, 14, 256))
+    verify_any_crop_and_resize(
+        data_shape=(1, 256, 234, 234),
+        boxes_shape=(relay.Any(), 4),
+        box_indices_shape=(relay.Any(),),
+        crop_size=(14, 14),
+        layout='NCHW',
+        static_boxes=(128, 4),
+        static_box_indices_shape=(128,),
+        ref_out_shape=(128, 256, 14, 14))
+
+def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shape):
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.nn.mirror_pad(data, pad_width)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape)
+
+def test_any_mirror_pad():
+    verify_any_mirror_pad(
+        data_shape=(1, 256, 232, 232),
+        pad_width=((0, 0), (0, 0), (1, 1), (1, 1)),
+        static_data_shape=(1, 256, 232, 232),
+        ref_out_shape=(1, 256, 234, 234))
+
 if __name__ == "__main__":
     test_any_full()
     test_any_full_like()
@@ -850,3 +907,5 @@ if __name__ == "__main__":
     test_recursive_concat_with_wrong_annotation()
     test_tuple_get_item()
     test_mixed_input_type()
+    test_any_crop_and_resize()
+    test_any_mirror_pad()