"""Backend compiler related feature registration"""
from __future__ import absolute_import
+from tvm.te.hybrid import script
+from tvm.runtime import convert
+
import topi
from topi.util import get_const_tuple
from .. import op as reg
reg.register_injective_schedule("image.crop_and_resize")
+
+# Hybrid script computing the output shape of crop_and_resize: the batch
+# dimension comes from boxes, H/W come from the static crop_size attribute,
+# and the channel dimension is carried over from the input image.
+@script
+def _crop_and_resize_func(image_shape, boxes_shape, crop_size,
+                          height_axis, width_axis, channel_axis):
+    out = output_tensor((4,), "int64")
+    out[0] = boxes_shape[0]
+    out[height_axis] = int64(crop_size[0])
+    out[width_axis] = int64(crop_size[1])
+    out[channel_axis] = image_shape[channel_axis]
+    return out
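+# For example, with an NHWC image of shape (1, 234, 234, 256), 128 boxes,
+# and crop_size = (14, 14), the computed output shape is (128, 14, 14, 256),
+# matching the dynamic-shape test below.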
+
+# The False flag registers this as a data-independent shape function: the
+# output shape depends only on input shapes and attributes, not tensor values.
+@reg.register_shape_func("image.crop_and_resize", False)
+def crop_and_resize_func(attrs, inputs, _):
+    """
+    Shape function for crop_and_resize op.
+    """
+    layout = attrs.layout
+    # Locate the H/W/C axes from the layout string (e.g. "NHWC" or "NCHW").
+    height_axis = width_axis = channel_axis = 1
+    for i, letter in enumerate(layout):
+        if letter == "H":
+            height_axis = i
+        if letter == "W":
+            width_axis = i
+        if letter == "C":
+            channel_axis = i
+    crop_size = get_const_tuple(attrs.crop_size)
+    return [_crop_and_resize_func(inputs[0], inputs[1], convert(crop_size),
+                                  convert(height_axis), convert(width_axis),
+                                  convert(channel_axis))]
+
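+# With the shape function in place, crop_and_resize calls whose boxes have a
+# relay.Any() batch dimension can be compiled for the Relay VM, which invokes
+# this function at runtime to size the output buffer.
+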
# dilation2d
reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy)
        assert result.asnumpy().shape == ref_out_shape, \
            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))

+def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
+                               layout, static_boxes, static_box_indices_shape, ref_out_shape):
+    mod = tvm.IRModule()
+    dtype = "float32"
+    indices_dtype = "int32"
+    # boxes and box_indices carry relay.Any() dims; their concrete shapes are
+    # only supplied through the numpy inputs at run time.
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    boxes = relay.var('boxes', shape=boxes_shape, dtype=dtype)
+    box_indices = relay.var('box_indices', shape=box_indices_shape, dtype=indices_dtype)
+    y = relay.image.crop_and_resize(data, boxes, box_indices, crop_size, layout)
+    mod["main"] = relay.Function([data, boxes, box_indices], y)
+    data_np = np.random.uniform(size=data_shape).astype(dtype)
+    boxes_np = np.random.uniform(size=static_boxes).astype(dtype)
+    box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np, boxes_np, box_indices_np)
+        tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape)
+
+def test_any_crop_and_resize():
+    verify_any_crop_and_resize(
+        data_shape=(1, 234, 234, 256),
+        boxes_shape=(relay.Any(), 4),
+        box_indices_shape=(relay.Any(),),
+        crop_size=(14, 14),
+        layout='NHWC',
+        static_boxes=(128, 4),
+        static_box_indices_shape=(128,),
+        ref_out_shape=(128, 14, 14, 256))
+    # Same test in NCHW so the layout-dependent axis lookup is also exercised.
+    verify_any_crop_and_resize(
+        data_shape=(1, 256, 234, 234),
+        boxes_shape=(relay.Any(), 4),
+        box_indices_shape=(relay.Any(),),
+        crop_size=(14, 14),
+        layout='NCHW',
+        static_boxes=(128, 4),
+        static_box_indices_shape=(128,),
+        ref_out_shape=(128, 256, 14, 14))
+
+def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shape):
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.nn.mirror_pad(data, pad_width)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape)
+
+def test_any_mirror_pad():
+    verify_any_mirror_pad(
+        data_shape=(1, 256, 232, 232),
+        pad_width=((0, 0), (0, 0), (1, 1), (1, 1)),
+        static_data_shape=(1, 256, 232, 232),
+        ref_out_shape=(1, 256, 234, 234))
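+# Shape arithmetic for the case above: mirror_pad grows each axis by the sum
+# of its pad widths, so H and W go from 232 to 232 + 1 + 1 = 234.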
+
if __name__ == "__main__":
    test_any_full()
    test_any_full_like()
    test_recursive_concat_with_wrong_annotation()
    test_tuple_get_item()
    test_mixed_input_type()
+    test_any_crop_and_resize()
+    test_any_mirror_pad()