From: Siju Date: Wed, 28 Nov 2018 01:49:02 +0000 (+0530) Subject: [Relay]resize op compute and schedule (#2172) X-Git-Tag: upstream/0.7.0~3030 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=55599b937ba69ef486241e599a6c44ae7c9ceb0b;p=platform%2Fupstream%2Ftvm.git [Relay]resize op compute and schedule (#2172) --- diff --git a/python/tvm/relay/op/image/__init__.py b/python/tvm/relay/op/image/__init__.py index 9d1415b1d..5fa5c0157 100644 --- a/python/tvm/relay/op/image/__init__.py +++ b/python/tvm/relay/op/image/__init__.py @@ -2,3 +2,4 @@ """Image network related operators.""" from __future__ import absolute_import as _abs from .image import * +from ._image import * diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py new file mode 100644 index 000000000..e44748372 --- /dev/null +++ b/python/tvm/relay/op/image/_image.py @@ -0,0 +1,7 @@ +#pylint: disable=invalid-name, unused-argument +"""Backend compiler related feature registration""" +from __future__ import absolute_import +from ..op import register_schedule, schedule_injective + +# resize +register_schedule("image.resize", schedule_injective) diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index bfa2ea4cd..e6efcb8ce 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -5,7 +5,10 @@ */ #include #include +#include +#include #include "../layout.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -40,6 +43,29 @@ bool ResizeRel(const Array& types, return true; } +Array ResizeCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + CHECK(param->layout == "NCHW" || param->layout == "NHWC"); + const auto* out_ttype = out_type.as(); + CHECK(out_ttype != nullptr); + Array oshape; + if (param->layout == "NCHW") { + oshape.push_back(out_ttype->shape[2]); + oshape.push_back(out_ttype->shape[3]); + } else if (param->layout == "NHWC") { + oshape.push_back(out_ttype->shape[1]); + oshape.push_back(out_ttype->shape[2]); + } + return Array{ topi::image::resize(inputs[0], + oshape, + param->layout, + param->align_corners, + param->method) }; +} // Positional relay function to create image operator // used by frontend FFI. @@ -82,7 +108,9 @@ RELAY_REGISTER_OP("image.resize") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) -.add_type_rel("Resize", ResizeRel); +.add_type_rel("Resize", ResizeRel) +.set_attr("FTVMCompute", ResizeCompute) +.set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 0bd7a4816..77e3f005d 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -1,7 +1,10 @@ """ Support level5 operator test cases. """ +import numpy as np import tvm from tvm import relay +from tvm.relay.testing import ctx_list +import topi.testing def test_resize_infer_type(): n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") @@ -17,6 +20,33 @@ def test_resize_infer_type(): zz = relay.ir_pass.infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") +def test_resize(): + def verify_resize(dshape, scale, method, layout): + if layout == "NHWC": + size = (dshape[1] * scale, dshape[2] * scale) + else: + size = (dshape[2] * scale, dshape[3] * scale) + + x_data = np.random.uniform(size=dshape).astype("float32") + if method == "BILINEAR": + ref_res = topi.testing.bilinear_resize_python(x_data, size, layout) + else: + ref_res = topi.testing.upsampling_python(x_data, scale, layout) + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.image.resize(x, size, layout, method, False) + assert "size=" in z.astext() + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") + func = relay.Function([x], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + for method in ["BILINEAR", "NEAREST_NEIGHBOR"]: + for layout in ["NHWC", "NCHW"]: + verify_resize((1, 4, 4, 4), 2, method, layout) def test_multibox_prior(): sizes = (0.3, 1.5, 0.7) @@ -74,5 +104,6 @@ def test_nms(): if __name__ == "__main__": test_resize_infer_type() + test_resize() test_multibox_prior() test_nms()