From fe761964029b1094b00925fe577bc8bf5f9f2d66 Mon Sep 17 00:00:00 2001 From: Siyuan Li Date: Fri, 24 Jul 2020 04:23:49 +0800 Subject: [PATCH] Register Shape Func for Some Operators to Handle Dynamic Shapes (#5955) * Register Shape Func for Floor Operator Register the shape function for `floor` operator. Otherwise, a bug will happen when input of floor is any. * Register shape func for log * add shape function for crop_and_size * change import location * add mirror_pad shape function * add test cases for crop_and_resize and mirror_pad shape funcs * support different layout * fix pylint error * fix pylint error * add test for nchw layout * block nchw test * test for nchw * use tvm.testing.assert_allclose instead Co-authored-by: lisiyuan --- python/tvm/relay/op/_tensor.py | 2 ++ python/tvm/relay/op/image/_image.py | 31 +++++++++++++++++++ python/tvm/relay/op/nn/_nn.py | 13 ++++++++ tests/python/relay/test_any.py | 61 +++++++++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index feeec1f..2ca2a01 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -239,3 +239,5 @@ register_shape_func("tan", False, elemwise_shape_func) register_shape_func("fast_exp", False, elemwise_shape_func) register_shape_func("fast_tanh", False, elemwise_shape_func) register_shape_func("fast_erf", False, elemwise_shape_func) +register_shape_func("floor", False, elemwise_shape_func) +register_shape_func("log", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index bcb110f..795844f 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -18,6 +18,9 @@ """Backend compiler related feature registration""" from __future__ import absolute_import +from tvm.te.hybrid import script +from tvm.runtime import convert + import topi from topi.util import get_const_tuple from .. import op as reg @@ -64,6 +67,34 @@ def compute_crop_and_resize(attrs, inputs, out_type): reg.register_injective_schedule("image.crop_and_resize") +@script +def _crop_and_resize_func(image_shape, boxes_shape, crop_size, + height_axis, width_axis, channel_axis): + out = output_tensor((4,), "int64") + out[0] = boxes_shape[0] + out[height_axis] = int64(crop_size[0]) + out[width_axis] = int64(crop_size[1]) + out[channel_axis] = image_shape[channel_axis] + return out + +@reg.register_shape_func("image.crop_and_resize", False) +def crop_and_resize_func(attrs, inputs, _): + """ + Shape function for crop_and_resize op. + """ + layout = attrs.layout + height_axis = width_axis = channel_axis = 1 + for i, letter in enumerate(layout): + if letter == "H": + height_axis = i + if letter == "W": + width_axis = i + if letter == "C": + channel_axis = i + crop_size = get_const_tuple(attrs.crop_size) + return [_crop_and_resize_func(inputs[0], inputs[1], convert(crop_size), + convert(height_axis), convert(width_axis), convert(channel_axis))] + # dilation2d reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 0757e96..cea592a 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -441,6 +441,19 @@ def compute_mirror_pad(attrs, inputs, out_dtype): reg.register_broadcast_schedule("nn.mirror_pad") +@script +def _mirror_pad_func(data_shape, pad_width): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(data_shape.shape[0]): + out[i] = data_shape[i] + int64(pad_width[i][0]) + int64(pad_width[i][1]) + return out + +@reg.register_shape_func("nn.mirror_pad", False) +def mirror_pad_func(attrs, inputs, _): + pad_width_tuple = [get_const_tuple(p) for p in attrs.pad_width] + return [_mirror_pad_func(inputs[0], convert(pad_width_tuple))] + + # conv2d_winograd related operators reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform", strategy.conv2d_winograd_without_weight_transfrom_strategy) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 6d940a5..6810d0b 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -816,6 +816,64 @@ def test_mixed_input_type(): assert result.asnumpy().shape == ref_out_shape, \ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) +def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size, + layout, static_boxes, static_box_indices_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + indices_dtype = "int32" + data = relay.var('data', shape=data_shape, dtype=dtype) + boxes = relay.var('boxes', shape=boxes_shape, dtype=dtype) + box_indices = relay.var('box_indices', shape=box_indices_shape, dtype=indices_dtype) + y = relay.image.crop_and_resize(data, boxes, box_indices, crop_size, layout) + mod["main"] = relay.Function([data, boxes, box_indices], y) + data_np = np.random.uniform(size=data_shape).astype(dtype) + boxes_np = np.random.uniform(size=static_boxes).astype(dtype) + box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np, boxes_np, box_indices_np) + tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape) + +def test_any_crop_and_resize(): + verify_any_crop_and_resize( + data_shape=(1, 234, 234, 256), + boxes_shape=(relay.Any(), 4), + box_indices_shape=(relay.Any(),), + crop_size=(14, 14), + layout='NHWC', + static_boxes=(128, 4), + static_box_indices_shape=(128,), + ref_out_shape=(128, 14, 14, 256)) + verify_any_crop_and_resize( + data_shape=(1, 256, 234, 234), + boxes_shape=(relay.Any(), 4), + box_indices_shape=(relay.Any(),), + crop_size=(14, 14), + layout='NCHW', + static_boxes=(128, 4), + static_box_indices_shape=(128,), + ref_out_shape=(128, 256, 14, 14) + ) + +def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.nn.mirror_pad(data, pad_width) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape) + +def test_any_mirror_pad(): + verify_any_mirror_pad( + data_shape=(1, 256, 232, 232), + pad_width=((0, 0), (0, 0), (1, 1), (1, 1)), + static_data_shape=(1, 256, 232, 232), + ref_out_shape=(1, 256, 234, 234)) + if __name__ == "__main__": test_any_full() test_any_full_like() @@ -850,3 +908,6 @@ if __name__ == "__main__": test_recursive_concat_with_wrong_annotation() test_tuple_get_item() test_mixed_input_type() + test_any_crop_and_resize() + test_any_mirror_pad() + -- 2.7.4