From: masahi Date: Tue, 28 May 2019 22:20:58 +0000 (+0900) Subject: [TOPI] Fix resize nearest with fractional scaling (#3244) X-Git-Tag: upstream/0.7.0~2376 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a8275bdbfb8cff2657bb94e867a97b38585a9b8a;p=platform%2Fupstream%2Ftvm.git [TOPI] Fix resize nearest with fractional scaling (#3244) --- diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py index b25feb7..3c56515 100644 --- a/nnvm/tests/python/compiler/test_top_level2.py +++ b/nnvm/tests/python/compiler/test_top_level2.py @@ -305,7 +305,7 @@ def test_upsampling_nearest_neighbor(): data = tvm.nd.array(a_np) m.run(x=data) out = m.get_output(0, tvm.nd.empty(oshape, dtype)) - b_np = topi.testing.upsampling_python(a_np, scale, "NCHW") + b_np = topi.testing.upsampling_python(a_np, (scale, scale), "NCHW") tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5) def test_upsampling_bilinear(): diff --git a/nnvm/tests/python/frontend/coreml/test_forward.py b/nnvm/tests/python/frontend/coreml/test_forward.py index 679afe4..7a9f294 100644 --- a/nnvm/tests/python/frontend/coreml/test_forward.py +++ b/nnvm/tests/python/frontend/coreml/test_forward.py @@ -195,7 +195,7 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): a_np = np.full(input_dim, 1, dtype=dtype) if mode == 'NN': - b_np = topi.testing.upsampling_python(a_np, scale) + b_np = topi.testing.upsampling_python(a_np, (scale, scale)) else: new_h = input_dim[2] * scale new_w = input_dim[3] * scale diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index 941a275..3365b0f 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -405,7 +405,7 @@ def _test_upsample_nearest(): y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.upsampling_python(in_array, scale, "NCHW") + out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") graph = helper.make_graph([y], 'upsample_nearest_test', diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index da78e96..0b6f91b 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -179,7 +179,7 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): a_np = np.full(input_dim, 1, dtype=dtype) if mode == 'NN': - b_np = topi.testing.upsampling_python(a_np, scale) + b_np = topi.testing.upsampling_python(a_np, (scale, scale)) else: new_h = input_dim[2] * scale new_w = input_dim[3] * scale diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 77f045a..095f1fe 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -417,7 +417,7 @@ def _test_upsample_nearest(): y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.upsampling_python(in_array, scale, "NCHW") + out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") graph = helper.make_graph([y], 'upsample_nearest_test', diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index a535045..c8f5b1d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -485,7 +485,7 @@ def _test_upsampling(layout, method): func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) if method == "NEAREST_NEIGHBOR": - ref = topi.testing.upsampling_python(data, scale, layout) + ref = topi.testing.upsampling_python(data, (scale, scale), layout) else: ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout) for target, ctx in ctx_list(): diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e6d99c7..21b227f 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -48,7 +48,7 @@ def test_resize(): if method == "BILINEAR": ref_res = topi.testing.bilinear_resize_python(x_data, size, layout) else: - ref_res = topi.testing.upsampling_python(x_data, scale, layout) + ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout) x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.image.resize(x, size, layout, method, False) assert "size=" in z.astext() diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index fb577a8..287ff94 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -101,15 +101,12 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input, out_shape.push_back(shape[1]); out_shape.push_back(input->shape[3]); - Expr h_ratio = shape[0] / input->shape[1]; - Expr w_ratio = shape[1] / input->shape[2]; - return compute( out_shape, [&](const Array& indices) { Array idx; idx.push_back(indices[0]); - idx.push_back(indices[1] / h_ratio); - idx.push_back(indices[2] / w_ratio); + idx.push_back(indices[1] * input->shape[1] / shape[0]); + idx.push_back(indices[2] * input->shape[2] / shape[1]); idx.push_back(indices[3]); return input(idx); @@ -138,16 +135,13 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input, out_shape.push_back(shape[0]); out_shape.push_back(shape[1]); - Expr h_ratio = shape[0] / input->shape[2]; - Expr w_ratio = shape[1] / input->shape[3]; - return compute( out_shape, [&](const Array& indices) { Array idx; idx.push_back(indices[0]); idx.push_back(indices[1]); - idx.push_back(indices[2] / h_ratio); - idx.push_back(indices[3] / w_ratio); + idx.push_back(indices[2] * input->shape[2] / shape[0]); + idx.push_back(indices[3] * input->shape[3] / shape[1]); return input(idx); }, name, tag); @@ -176,16 +170,13 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input, out_shape.push_back(shape[1]); out_shape.push_back(input->shape[4]); - Expr h_ratio = shape[0] / input->shape[2]; - Expr w_ratio = shape[1] / input->shape[3]; - return compute( out_shape, [&](const Array& indices) { Array idx; idx.push_back(indices[0]); idx.push_back(indices[1]); - idx.push_back(indices[2] / h_ratio); - idx.push_back(indices[3] / w_ratio); + idx.push_back(indices[2] * input->shape[2] / shape[0]); + idx.push_back(indices[3] * input->shape[3] / shape[1]); idx.push_back(indices[4]); return input(idx); diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 14c7c05..7926df2 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -53,5 +53,4 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale)) else: raise ValueError("not support this layout {} yet".format(layout)) - return topi.cpp.nn.upsampling(data, out_shape, layout, method) diff --git a/topi/python/topi/testing/upsampling_python.py b/topi/python/topi/testing/upsampling_python.py index 8ee9640..167fdfc 100644 --- a/topi/python/topi/testing/upsampling_python.py +++ b/topi/python/topi/testing/upsampling_python.py @@ -16,25 +16,35 @@ # under the License. # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals """Upsampling in python""" +import math import numpy as np def upsample_nearest(arr, scale): """ Populate the array by scale factor""" - return arr.repeat(scale, axis=0).repeat(scale, axis=1) + h, w = arr.shape + out_h = math.floor(h * scale[0]) + out_w = math.floor(w * scale[1]) + out = np.empty((out_h, out_w)) + for y in range(out_h): + for x in range(out_w): + in_y = math.floor(y / scale[0]) + in_x = math.floor(x / scale[1]) + out[y, x] = arr[in_y, in_x] + return out def upsampling_python(data, scale, layout='NCHW'): """ Python version of scaling using nearest neighbour """ ishape = data.shape if layout == 'NCHW': - oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale) + oshape = (ishape[0], ishape[1], math.floor(ishape[2]*scale[0]), math.floor(ishape[3]*scale[1])) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[1]): output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) return output_np if layout == 'NHWC': - oshape = (ishape[0], ishape[1]*scale, ishape[1]*scale, ishape[3]) + oshape = (ishape[0], math.floor(ishape[1]*scale[0]), math.floor(ishape[1]*scale[1]), ishape[3]) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[3]): diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 26a5e35..8277886 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -23,8 +23,7 @@ import math from common import get_all_backend -def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False): - +def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False, method="BILINEAR"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32') dtype = A.dtype @@ -39,9 +38,14 @@ def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, ou raise NotImplementedError( 'Layout not supported {} '.format(layout)) - B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners) + B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method) - b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) + if method == "BILINEAR": + b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) + else: + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) def check_device(device): ctx = tvm.context(device, 0) @@ -61,15 +65,19 @@ def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, ou for device in get_all_backend(): check_device(device) + def test_resize(): # Scale NCHW - verify_bilinear_scale(4, 16, 32, 32, 50, 50, 'NCHW') + verify_resize(4, 16, 32, 32, 50, 50, 'NCHW') # Scale NCHW + Align Corners - verify_bilinear_scale(6, 32, 64, 64, 20, 20, 'NCHW', True) + verify_resize(6, 32, 64, 64, 20, 20, 'NCHW', True) # Scale NHWC - verify_bilinear_scale(4, 16, 32, 32, 50, 50, "NHWC") + verify_resize(4, 16, 32, 32, 50, 50, "NHWC") # Scale NHWC + Align Corners - verify_bilinear_scale(6, 32, 64, 64, 20, 20, "NHWC", True) + verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True) + # Nearest + Fractional + verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="NEAREST_NEIGHBOR") + verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="NEAREST_NEIGHBOR") if __name__ == "__main__": test_resize() diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index 0838f02..ddfb002 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -46,7 +46,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH out_size = (in_height*scale, in_width*scale) b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout) else: - b_np = topi.testing.upsampling_python(a_np, scale, layout) + b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout) def check_device(device): ctx = tvm.context(device, 0)