From 5c1bf981130bd942f0617d3c5a5d3ff5c76f74a4 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 1 Jul 2020 08:04:15 -0700 Subject: [PATCH] Add MXnNet parser for box_decode (#5967) --- python/tvm/relay/frontend/mxnet.py | 37 +++++++++++++++++++++++++++++ tests/python/frontend/mxnet/test_forward.py | 23 ++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 321b145..135756b 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -978,6 +978,42 @@ def _mx_box_nms(inputs, attrs): return nms_out +def _mx_box_decode(inputs, attrs): + std0 = relay.const(attrs.get_float('std0', 1), "float32") + std1 = relay.const(attrs.get_float('std1', 1), "float32") + std2 = relay.const(attrs.get_float('std2', 1), "float32") + std3 = relay.const(attrs.get_float('std3', 1), "float32") + clip = attrs.get_float('clip', -1) + in_format = attrs.get_str('format', 'corner') + + anchors = inputs[1] # (1, N, 4) encoded in corner or center + a = _op.split(anchors, indices_or_sections=4, axis=-1) + # Convert to format "center". + if in_format == "corner": + a_width = a[2] - a[0] + a_height = a[3] - a[1] + a_x = a[0] + a_width * relay.const(0.5, "float32") + a_y = a[1] + a_height * relay.const(0.5, "float32") + else: + a_x, a_y, a_width, a_height = a + data = inputs[0] # (B, N, 4) predicted bbox offset + p = _op.split(data, indices_or_sections=4, axis=-1) + ox = p[0] * std0 * a_width + a_x + oy = p[1] * std1 * a_height + a_y + dw = p[2] * std2 + dh = p[3] * std3 + if clip > 0: + clip = relay.const(clip, "float32") + dw = _op.minimum(dw, clip) + dh = _op.minimum(dh, clip) + dw = _op.exp(dw) + dh = _op.exp(dh) + ow = dw * a_width * relay.const(0.5, "float32") + oh = dh * a_height * relay.const(0.5, "float32") + out = _op.concatenate([ox - ow, oy - oh, ox + ow, oy + oh], axis=-1) + return out + + def _mx_l2_normalize(inputs, attrs): new_attrs = {} mode = attrs.get_str('mode', 'instance') @@ -2220,6 +2256,7 @@ _convert_map = { "_contrib_Proposal" : _mx_proposal, "_contrib_MultiProposal" : _mx_proposal, "_contrib_box_nms" : _mx_box_nms, + "_contrib_box_decode" : _mx_box_decode, "_contrib_DeformableConvolution" : _mx_deformable_convolution, "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling, "GridGenerator" : _mx_grid_generator, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index ae5ed45..4d8b1e9 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1306,6 +1306,28 @@ def test_forward_interleaved_matmul_selfatt_valatt(): verify(3, 10, 6, 8) +def test_forward_box_decode(): + def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corner"): + dtype = "float32" + data = np.random.uniform(low=-2, high=2, size=data_shape).astype(dtype) + anchors = np.random.uniform(low=-2, high=2, size=anchor_shape).astype(dtype) + ref_res = mx.nd.contrib.box_decode(mx.nd.array(data), mx.nd.array(anchors), stds[0], stds[1], stds[2], stds[3], clip, in_format) + mx_sym = mx.sym.contrib.box_decode(mx.sym.var("data"), mx.sym.var("anchors"), stds[0], stds[1], stds[2], stds[3], clip, in_format) + shape_dict = {"data": data_shape, "anchors": anchor_shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data, anchors) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((1, 10, 4), (1, 10, 4)) + verify((4, 10, 4), (1, 10, 4)) + verify((1, 10, 4), (1, 10, 4), stds=[2, 3, 0.5, 1.5]) + verify((1, 10, 4), (1, 10, 4), clip=1) + verify((1, 10, 4), (1, 10, 4), in_format="center") + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1379,3 +1401,4 @@ if __name__ == '__main__': test_forward_arange_like() test_forward_interleaved_matmul_selfatt_qk() test_forward_interleaved_matmul_selfatt_valatt() + test_forward_box_decode() -- 2.7.4