From 7264cb6a5996af874819f03cf27aca3b5ad48814 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 21 Aug 2019 20:39:09 -0700 Subject: [PATCH] Changed topi cc resize to python implementation with new features. (#3788) --- include/tvm/relay/attrs/image.h | 13 +- include/tvm/relay/attrs/nn.h | 10 +- nnvm/tests/python/compiler/test_top_level2.py | 4 +- nnvm/tests/python/frontend/onnx/test_forward.py | 4 +- python/tvm/relay/frontend/coreml.py | 2 +- python/tvm/relay/frontend/keras.py | 11 +- python/tvm/relay/frontend/onnx.py | 6 +- python/tvm/relay/frontend/tensorflow.py | 6 +- python/tvm/relay/frontend/tflite.py | 6 +- python/tvm/relay/op/image/_image.py | 17 +- python/tvm/relay/op/image/image.py | 14 +- python/tvm/relay/op/nn/_nn.py | 7 + python/tvm/relay/op/nn/nn.py | 12 +- src/relay/op/image/resize.cc | 38 ++--- src/relay/op/nn/upsampling.cc | 39 +---- tests/python/frontend/keras/test_forward.py | 2 +- tests/python/frontend/tensorflow/test_forward.py | 2 +- tests/python/relay/test_op_level2.py | 22 +-- tests/python/relay/test_op_level5.py | 8 +- topi/include/topi/image/resize.h | 1 - topi/python/topi/image/resize.py | 175 ++++++++++++++++++++- topi/python/topi/nn/upsampling.py | 7 +- topi/python/topi/testing/bilinear_resize_python.py | 2 +- topi/tests/python/test_topi_resize.py | 8 +- topi/tests/python/test_topi_upsampling.py | 24 +-- 25 files changed, 296 insertions(+), 144 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index e74c0f5..dd3a0aa 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -37,6 +37,7 @@ struct ResizeAttrs : public tvm::AttrsNode { std::string layout; std::string method; bool align_corners; + DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { TVM_ATTR_FIELD(size).set_default(NullValue >()) @@ -46,12 +47,16 @@ struct ResizeAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Resize is applied on the 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("BILINEAR") + TVM_ATTR_FIELD(method).set_default("bilinear") .describe("Specify the mode to use for scaling." - "NEAREST_NEIGHBOR - Nearest Neighbor" - "BILINEAR - Bilinear Interpolation"); - TVM_ATTR_FIELD(align_corners).set_default(false) + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(align_corners).set_default(true) .describe("Should be true to preserve the values at the corner pixels"); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type."); } }; diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index cbbcc2f..fd3f4c8 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -381,6 +381,7 @@ struct UpSamplingAttrs : public tvm::AttrsNode { int scale; std::string layout; std::string method; + bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { TVM_ATTR_FIELD(scale) @@ -390,10 +391,13 @@ struct UpSamplingAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Upsampling is applied on the 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("NEAREST_NEIGHBOR") + TVM_ATTR_FIELD(method).set_default("nearest_neighbor") .describe("Specify the mode to use for scaling." - "NEAREST_NEIGHBOR - Nearest Neighbor" - "BILINEAR - Bilinear Interpolation"); + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(align_corners).set_default(false) + .describe("Should be true to preserve the values at the corner pixels"); } }; diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py index 3c56515..b558428 100644 --- a/nnvm/tests/python/compiler/test_top_level2.py +++ b/nnvm/tests/python/compiler/test_top_level2.py @@ -324,12 +324,12 @@ def test_upsampling_bilinear(): data = tvm.nd.array(a_np) m.run(x=data) out = m.get_output(0, tvm.nd.empty(oshape, dtype)) - b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NCHW") + b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NCHW", align_corners=False) tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5) def test_resize_bilinear(): x = sym.Variable("x") - y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC") + y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC", align_corners=True) dtype = "float32" dshape = (1, 32, 32, 4) oshape = (1, 60, 60, 4) diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index 3365b0f..8cb6876 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -425,7 +425,7 @@ def _test_upsample_bilinear(): y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW") + out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW", align_corners=False) graph = helper.make_graph([y], 'upsample_bilinear_test', @@ -445,7 +445,7 @@ def _test_upsample_bilinear_opset9(): y = helper.make_node("Upsample", ['in','scales'], ['out'], mode='linear') scales=[1.0, 1.0, 2.0, 2.0] in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW") + out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW", align_corners=False) ref_array = np.array(scales) ref_node = helper.make_node('Constant', diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index ba5438e..8b158ca 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -312,7 +312,7 @@ def _UpsampleLayerParams(op, inexpr, etab): if op.scalingFactor[0] != op.scalingFactor[1]: raise tvm.error.OpAttributeUnimplemented( 'Upsample height and width must be equal.') - interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' + interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear' return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 845543f..4d3b976 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -358,29 +358,30 @@ def _convert_pooling(inexpr, keras_layer, etab): def _convert_upsample(inexpr, keras_layer, _): _check_data_format(keras_layer) upsample_type = type(keras_layer).__name__ + params = {'layout': 'NHWC'} if upsample_type == 'UpSampling1D': h = keras_layer.size - params = {'scale': h} + params['scale'] = h elif upsample_type == 'UpSampling2D': h, w = keras_layer.size if h != w: raise tvm.error.OpAttributeInvalid( 'Height must equal width for operator Upsample.') - params = {'scale': h} + params['scale'] = h if hasattr(keras_layer, 'interpolation'): interpolation = keras_layer.interpolation if interpolation == 'nearest': - params['method'] = 'NEAREST_NEIGHBOR' + params['method'] = 'nearest_neighbor' else: - params['method'] = 'BILINEAR' + params['method'] = 'bilinear' elif upsample_type == 'UpSampling3D': h, w, d = keras_layer.size if h != w or w != d: raise tvm.error.OpAttributeInvalid( 'Height, width, and depth must all be equal for operator Upsample.') - params = {'scale': h} + params['scale'] = h else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend Keras.'.format(upsample_type)) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b7d668b..11d73e2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -559,13 +559,13 @@ class Upsample(OnnxOpConverter): assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3] mode = attr.get('mode') if mode == b'nearest': - method = "NEAREST_NEIGHBOR" + method = "nearest_neighbor" elif mode == b'linear': - method = "BILINEAR" + method = "bilinear" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'} + attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW', 'align_corners':True} return AttrCvt('upsampling')(inputs, attr) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index bbc0fec..54724b5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -358,7 +358,7 @@ def _crop_and_resize(): 'Attribute method=nearest is not supported') else: attrs['align_corners'] = True - attrs['method'] = 'BILINEAR' + attrs['method'] = 'bilinear' out = None begin = [0] * data_dim @@ -408,7 +408,7 @@ def _resize_bilinear(): return AttrCvt(op_name="resize", ignores=['Tdim'], - extras={'method': "BILINEAR"})(inputs, attr) + extras={'method': "bilinear"})(inputs, attr) return _impl def _resize_nearest_neighbor(): @@ -423,7 +423,7 @@ def _resize_nearest_neighbor(): return AttrCvt(op_name="resize", ignores=['Tdim'], - extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr) + extras={'method': "nearest_neighbor"})(inputs, attr) return _impl def _check_numerics(): diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 162cc36..f4c10f2 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -262,7 +262,7 @@ class OperatorConverter(object): # Options - align_corners (bool) resize_options = None align_corners = False - if method == "BILINEAR": + if method == "bilinear": assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions resize_options = ResizeBilinearOptions() elif tflite_ver >= 1130: @@ -280,11 +280,11 @@ class OperatorConverter(object): def convert_resize_bilinear(self, op): """Convert TFLite RESIZE_BILINEAR""" - return self._convert_resize("BILINEAR", op) + return self._convert_resize("bilinear", op) def convert_resize_nearest_neighbor(self, op): """Convert TFLite RESIZE_NEAREST_NEIGHBOR""" - return self._convert_resize("NEAREST_NEIGHBOR", op) + return self._convert_resize("nearest_neighbor", op) def convert_logistic(self, op): """Convert TFLite LOGISTIC""" diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index cd0df33..fcebfd8 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -17,7 +17,20 @@ #pylint: disable=invalid-name, unused-argument """Backend compiler related feature registration""" from __future__ import absolute_import -from ..op import register_schedule, schedule_injective + +import topi +from .. import op as reg +from ..op import schedule_injective # resize -register_schedule("image.resize", schedule_injective) +reg.register_schedule("image.resize", schedule_injective) + + +@reg.register_compute("image.resize") +def compute_resize(attrs, inputs, out_type, target): + size = attrs.size + layout = attrs.layout + method = attrs.method + align_corners = attrs.align_corners + out_dtype = attrs.out_dtype + return [topi.image.resize(inputs[0], size, layout, method, align_corners, out_dtype)] diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 2861bd5..c54e438 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -21,8 +21,9 @@ from . import _make def resize(data, size, layout="NCHW", - method="BILINEAR", - align_corners=False): + method="bilinear", + align_corners=True, + out_dtype=None): """Image resize operator. This operator takes data as input and does 2D scaling to the given scale factor. @@ -31,7 +32,7 @@ def resize(data, out will have a shape (n, c, size[0], size[1]) method indicates the algorithm to be used while calculating ghe out value - and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") + and method can be one of ("bilinear", "nearest_neighbor", "bicubic") Parameters ---------- @@ -45,14 +46,17 @@ def resize(data, Layout of the input. method : str, optional - Scale method to used [NEAREST_NEIGHBOR, BILINEAR]. + Scale method to used [nearest_neighbor, bilinear, bicubic]. align_corners : int, optional Should be true to preserve the values at the corner pixels + out_dtype : str, optional + Type to return. If left None returns the same type as input. + Returns ------- result: relay.Expr The resized result. """ - return _make.resize(data, size, layout, method, align_corners) + return _make.resize(data, size, layout, method, align_corners, out_dtype) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 6dcff6b..9b4caa1 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -376,6 +376,13 @@ def schedule_upsampling(_, outs, target): with target: return topi.generic.schedule_injective(outs) +@reg.register_compute("nn.upsampling") +def compute_upsampling(attrs, inputs, out_dtype, target): + scale = attrs.scale + layout = attrs.layout + method = attrs.method + align_corners = attrs.align_corners + return [topi.nn.upsampling(inputs[0], scale, layout, method, align_corners)] # pad reg.register_schedule("nn.pad", schedule_broadcast) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b311aa9..4b7f52e 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -481,7 +481,8 @@ def global_avg_pool2d(data, def upsampling(data, scale=1, layout="NCHW", - method="NEAREST_NEIGHBOR"): + method="nearest_neighbor", + align_corners=False): """Upsampling. This operator takes data as input and does 2D scaling to the given scale factor. @@ -490,7 +491,7 @@ def upsampling(data, out will have a shape (n, c, h*scale, w*scale) method indicates the algorithm to be used while calculating the out value - and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") + and method can be one of ("bilinear", "nearest_neighbor", "bicubic") Parameters ---------- @@ -504,14 +505,17 @@ def upsampling(data, Layout of the input. method : str, optional - Scale method to used [NEAREST_NEIGHBOR, BILINEAR]. + Scale method to used [nearest_neighbor, bilinear, bicubic]. + + align_corners : bool, optional + Whether to keep corners in proper place. Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.upsampling(data, scale, layout, method) + return _make.upsampling(data, scale, layout, method, align_corners) def batch_flatten(data): diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index ffa489e..dbdf897 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -25,8 +25,6 @@ #include #include #include -#include -#include #include "../op_common.h" namespace tvm { @@ -56,49 +54,32 @@ bool ResizeRel(const Array& types, oshape.Set(2, param->size[0]); oshape.Set(3, param->size[1]); + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + // assign output type reporter->Assign(types[1], TensorTypeNode::make(layout_converter.BackwardShape(oshape), - data->dtype)); + out_dtype)); return true; } -Array ResizeCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type, - const Target& target) { - const auto* param = attrs.as(); - CHECK(param != nullptr); - CHECK(param->layout == "NCHW" || param->layout == "NHWC"); - const auto* out_ttype = out_type.as(); - CHECK(out_ttype != nullptr); - Array oshape; - if (param->layout == "NCHW") { - oshape.push_back(out_ttype->shape[2]); - oshape.push_back(out_ttype->shape[3]); - } else if (param->layout == "NHWC") { - oshape.push_back(out_ttype->shape[1]); - oshape.push_back(out_ttype->shape[2]); - } - return Array{ topi::image::resize(inputs[0], - oshape, - param->layout, - param->align_corners, - param->method) }; -} - // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Array size, std::string layout, std::string method, - bool align_corners) { + bool align_corners, + DataType out_dtype) { auto attrs = make_node(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->align_corners = align_corners; + attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return CallNode::make(op, {data}, Attrs(attrs), {}); } @@ -127,7 +108,6 @@ RELAY_REGISTER_OP("image.resize") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) .add_type_rel("Resize", ResizeRel) -.set_attr("FTVMCompute", ResizeCompute) .set_attr("TOpPattern", kInjective); } // namespace relay diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 1ee668a..9989203 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -27,8 +27,6 @@ #include #include #include -#include -#include #include #include "../op_common.h" @@ -99,11 +97,13 @@ bool UpSamplingRel(const Array& types, Expr MakeUpSampling(Expr data, int scale, std::string layout, - std::string method) { + std::string method, + bool align_corners) { auto attrs = make_node(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->scale = scale; + attrs->align_corners = align_corners; static const Op& op = Op::Get("nn.upsampling"); return CallNode::make(op, {data}, Attrs(attrs), {}); } @@ -135,38 +135,7 @@ RELAY_REGISTER_OP("nn.upsampling") .add_type_rel("UpSampling", UpSamplingRel) .set_attr("FInferCorrectLayout", UpsamplingInferCorrectLayout) -.set_attr("TOpPattern", kInjective) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type, - const Target& target) { - const auto* uattrs = attrs.as(); - CHECK(uattrs != nullptr); - auto out_tt = out_type.as(); - CHECK(out_tt) << "expected a tensor type: " << out_type; - const auto layout = uattrs->layout; - const auto base_layout = layout.substr(0, 4); - CHECK(base_layout == "NCHW" || layout == "NHWC") - << "unknown layout: " << uattrs->layout; - - Array oshape; - if (base_layout == "NCHW") { - oshape.push_back(out_tt->shape[2]); - oshape.push_back(out_tt->shape[3]); - } else if (layout == "NHWC") { - oshape.push_back(out_tt->shape[1]); - oshape.push_back(out_tt->shape[2]); - } - - return Array{ - topi::nn::upsampling( - inputs[0], - oshape, - uattrs->layout, - uattrs->method) - }; -}); +.set_attr("TOpPattern", kInjective); } // namespace relay diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 7658db2..f571370 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -172,7 +172,7 @@ def test_forward_upsample(interpolation='nearest'): data = keras.layers.Input(shape=(32, 32, 3)) x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data) keras_model = keras.models.Model(data, x) - verify_keras_frontend(keras_model) + verify_keras_frontend(keras_model, need_transpose=False) def test_forward_reshape(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 7b0bcfb..eb8e27e 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1212,7 +1212,7 @@ def test_forward_crop_and_resize(): _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5]) _test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5]) _test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4]) - _test_forward_crop_and_resize([1, 106, 106, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]) + _test_forward_crop_and_resize([1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]) _test_forward_crop_and_resize([10, 11, 11, 3], [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]], [0, 1], diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 4e8fe2c..5e9abdf 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -233,13 +233,13 @@ def test_conv2d_transpose_run(): def test_upsampling_infer_type(): n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") + y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear") "method=\"BINLINEAR\"" in y.astext() yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32") n, c = tvm.var("n"), tvm.var("c") x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) - y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") + y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear") yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") @@ -502,7 +502,7 @@ def test_batch_flatten(): np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) -def _test_upsampling(layout, method): +def _test_upsampling(layout, method, align_corners=False): n, c, h, w = tvm.var("n"), 16, 32, 32 scale = 2 dtype = "float32" @@ -513,15 +513,17 @@ def _test_upsampling(layout, method): return (h, w, c), (h*scale, w*scale, c) ishape, oshape = get_shape() x = relay.var("x", relay.TensorType((n,) + ishape, dtype)) - y = relay.nn.upsampling(x, scale=scale, layout=layout, method=method) + y = relay.nn.upsampling(x, scale=scale, layout=layout, + method=method, align_corners=align_corners) yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n,) + oshape, dtype) dshape = (1,) + ishape x = relay.var("x", shape=dshape) - y = relay.nn.upsampling(x, scale=scale, layout=layout, method=method) + y = relay.nn.upsampling(x, scale=scale, layout=layout, + method=method, align_corners=align_corners) func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) - if method == "NEAREST_NEIGHBOR": + if method == "nearest_neighbor": ref = topi.testing.upsampling_python(data, (scale, scale), layout) else: ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout) @@ -532,10 +534,10 @@ def _test_upsampling(layout, method): def test_upsampling(): - _test_upsampling("NCHW", "NEAREST_NEIGHBOR") - _test_upsampling("NCHW", "BILINEAR") - _test_upsampling("NHWC", "NEAREST_NEIGHBOR") - _test_upsampling("NHWC", "BILINEAR") + _test_upsampling("NCHW", "nearest_neighbor") + _test_upsampling("NCHW", "bilinear", True) + _test_upsampling("NHWC", "nearest_neighbor") + _test_upsampling("NHWC", "bilinear", True) def test_conv2d_int8_intrinsics(): diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 328e4d5..f4ac673 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -39,7 +39,7 @@ def test_resize_infer_type(): assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) - z= relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False) + z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", True) assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") @@ -52,12 +52,12 @@ def test_resize(): size = (dshape[2] * scale, dshape[3] * scale) x_data = np.random.uniform(size=dshape).astype("float32") - if method == "BILINEAR": + if method == "bilinear": ref_res = topi.testing.bilinear_resize_python(x_data, size, layout) else: ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout) x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.image.resize(x, size, layout, method, False) + z = relay.image.resize(x, size, layout, method, True) assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") @@ -68,7 +68,7 @@ def test_resize(): intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - for method in ["BILINEAR", "NEAREST_NEIGHBOR"]: + for method in ["bilinear", "nearest_neighbor"]: for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout) diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index 7d7143c..e44f5a7 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -201,7 +201,6 @@ inline Tensor resize_nearest_neighbor(const Tensor& input, bool align_corners = false, std::string name = "tensor", std::string tag = kInjective) { - CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour"; auto base_layout = layout.substr(0, 4); if (layout == "NHWC") { return resize_nearest_neighbor_nhwc(input, shape, align_corners); diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 60fcf37..7a24990 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """TVM operator input resize compute.""" from __future__ import absolute_import -import topi +import tvm +from .. import tag -def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"): + +def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None): """Perform resize operation on the data. Parameters @@ -32,18 +35,178 @@ def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"): Output resolution scale to layout: string, optional - either "NCHW" or "NHWC" + "NCHW", "NHWC", or "NCHWc". align_corners: Boolean, optional - To preserve the values at the corner pixels + To preserve the values at the corner pixels. - method: {"BILINEAR", "NEAREST_NEIGHBOR"} + method: {"bilinear", "nearest_neighbor", "bicubic"} Method to be used for resizing. + out_dtype: string, optional + Type to return. If left None will be same as input type. + Returns ------- output : tvm.Tensor 4-D with shape [batch, channel, in_height*scale, in_width*scale] or [batch, in_height*scale, in_width*scale, channel] + or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] """ - return topi.cpp.image.resize(data, size, layout, align_corners, method) + method = method.lower() + + if layout == 'NHWC': + in_n, in_h, in_w, in_c = data.shape + output_shape = [in_n, size[0], size[1], in_c] + elif layout == 'NCHW': + in_n, in_c, in_h, in_w = data.shape + output_shape = [in_n, in_c, size[0], size[1]] + # Otherwise layout must be NCHWxc + else: + in_n, in_c, in_h, in_w, in_cc = data.shape + output_shape = [in_n, in_c, size[0], size[1], in_cc] + + if align_corners: + y_ratio = (in_h - 1).astype('float') / (size[0] - 1) + x_ratio = (in_w - 1).astype('float') / (size[1] - 1) + else: + y_ratio = (in_h).astype('float') / (size[0]) + x_ratio = (in_w).astype('float') / (size[1]) + + def _get_pixel(n, c, y, x, cc): + y = tvm.max(tvm.min(y, in_h - 1), 0) + x = tvm.max(tvm.min(x, in_w - 1), 0) + if layout == 'NHWC': + return data(n, y, x, c).astype('float') + if layout == 'NCHW': + return data(n, c, y, x).astype('float') + # else must be NCHWxc + return data(n, c, y, x, cc).astype('float') + + def _get_indices(*indices): + if layout == 'NHWC': + n, y, x, c = indices + cc = None + elif layout == 'NCHW': + n, c, y, x = indices + cc = None + else: + n, c, y, x, cc = indices + + return n, c, y, x, cc + + def _cast_output(value): + if out_dtype: + dtype = out_dtype + else: + dtype = data.dtype + return value.astype(dtype) + + # Nearest neighbor computation + def _nearest_neighbor(*indices): + n, c, y, x, cc = _get_indices(*indices) + + in_y = y_ratio * y + in_x = x_ratio * x + + if align_corners: + yint = tvm.round(in_y).astype('int32') + xint = tvm.round(in_x).astype('int32') + else: + # Add epsilon to floor to prevent gpu rounding errors. + epsilon = 1e-5 + yint = tvm.floor(in_y + epsilon).astype('int32') + xint = tvm.floor(in_x + epsilon).astype('int32') + + return _cast_output(_get_pixel(n, c, yint, xint, cc)) + + # Bilinear helper functions and computation. + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + def _bilinear(*indices): + n, c, y, x, cc = _get_indices(*indices) + + in_y = y_ratio * y + in_x = x_ratio * x + + xint = tvm.floor(in_x).astype('int32') + xfract = in_x - tvm.floor(in_x) + + yint = tvm.floor(in_y).astype('int32') + yfract = in_y - tvm.floor(in_y) + + p00 = _get_pixel(n, c, yint, xint, cc) + p10 = _get_pixel(n, c, yint, xint + 1, cc) + p01 = _get_pixel(n, c, yint + 1, xint, cc) + p11 = _get_pixel(n, c, yint + 1, xint + 1, cc) + + col0 = _lerp(p00, p10, xfract) + col1 = _lerp(p01, p11, xfract) + value = _lerp(col0, col1, yfract) + return _cast_output(value) + + # Bicubic helper function and computation. + def _cubic_kernel(A, B, C, D, t): + a = -A / 2.0 + (3.0*B) / 2.0 - (3.0*C) / 2.0 + D / 2.0 + b = A - (5.0*B) / 2.0 + 2.0*C - D / 2.0 + c = -A / 2.0 + C / 2.0 + d = B + + return a*t*t*t + b*t*t + c*t + d + + def _bicubic(*indices): + n, c, y, x, cc = _get_indices(*indices) + + in_y = y_ratio * y + in_x = x_ratio * x + + xint = tvm.floor(in_x).astype('int32') + xfract = in_x - tvm.floor(in_x) + + yint = tvm.floor(in_y).astype('int32') + yfract = in_y - tvm.floor(in_y) + + # 1st row + p00 = _get_pixel(n, c, yint - 1, xint - 1, cc) + p10 = _get_pixel(n, c, yint - 1, xint + 0, cc) + p20 = _get_pixel(n, c, yint - 1, xint + 1, cc) + p30 = _get_pixel(n, c, yint - 1, xint + 2, cc) + + # 2nd row + p01 = _get_pixel(n, c, yint + 0, xint - 1, cc) + p11 = _get_pixel(n, c, yint + 0, xint + 0, cc) + p21 = _get_pixel(n, c, yint + 0, xint + 1, cc) + p31 = _get_pixel(n, c, yint + 0, xint + 2, cc) + + # 3rd row + p02 = _get_pixel(n, c, yint + 1, xint - 1, cc) + p12 = _get_pixel(n, c, yint + 1, xint + 0, cc) + p22 = _get_pixel(n, c, yint + 1, xint + 1, cc) + p32 = _get_pixel(n, c, yint + 1, xint + 2, cc) + + # 4th row + p03 = _get_pixel(n, c, yint + 2, xint - 1, cc) + p13 = _get_pixel(n, c, yint + 2, xint + 0, cc) + p23 = _get_pixel(n, c, yint + 2, xint + 1, cc) + p33 = _get_pixel(n, c, yint + 2, xint + 2, cc) + + # Interpolate bicubically + col0 = _cubic_kernel(p00, p10, p20, p30, xfract) + col1 = _cubic_kernel(p01, p11, p21, p31, xfract) + col2 = _cubic_kernel(p02, p12, p22, p32, xfract) + col3 = _cubic_kernel(p03, p13, p23, p33, xfract) + value = _cubic_kernel(col0, col1, col2, col3, yfract) + return _cast_output(value) + + # Determine which interpolation method to use then run it. + if method == "nearest_neighbor": + compute_func = _nearest_neighbor + elif method == "bilinear": + compute_func = _bilinear + elif method == "bicubic": + compute_func = _bicubic + else: + raise ValueError('%s method is not supported.' % method) + + return tvm.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE) diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 7926df2..6092136 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -20,7 +20,7 @@ import topi from ..util import simplify -def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): +def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corners=False): """Perform upsampling on the data. Nearest neighbor and bilinear upsampling are supported. @@ -37,7 +37,7 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): layout : string, optional either "NCHW" or "NHWC" - method : {"BILINEAR", "NEAREST_NEIGHBOR"} + method : {"bilinear", "nearest_neighbor", "bicubic"} Method to be used for upsampling. Returns @@ -53,4 +53,5 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale)) else: raise ValueError("not support this layout {} yet".format(layout)) - return topi.cpp.nn.upsampling(data, out_shape, layout, method) + return topi.image.resize(data, out_shape, layout=layout, + method=method, align_corners=align_corners) diff --git a/topi/python/topi/testing/bilinear_resize_python.py b/topi/python/topi/testing/bilinear_resize_python.py index 6764ed3..86dd450 100644 --- a/topi/python/topi/testing/bilinear_resize_python.py +++ b/topi/python/topi/testing/bilinear_resize_python.py @@ -19,7 +19,7 @@ import math import numpy as np -def bilinear_resize_python(image, out_size, layout, align_corners=False): +def bilinear_resize_python(image, out_size, layout, align_corners=True): """ Bilinear scaling using python""" (new_h, new_w) = out_size diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 8277886..7c33526 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -23,7 +23,7 @@ import math from common import get_all_backend -def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False, method="BILINEAR"): +def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=True, method="bilinear"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32') dtype = A.dtype @@ -40,7 +40,7 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method) - if method == "BILINEAR": + if method == "bilinear": b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) else: scale_h = out_height / in_height @@ -76,8 +76,8 @@ def test_resize(): # Scale NHWC + Align Corners verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True) # Nearest + Fractional - verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="NEAREST_NEIGHBOR") - verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="NEAREST_NEIGHBOR") + verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False) + verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False) if __name__ == "__main__": test_resize() diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index ddfb002..f878c23 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -23,7 +23,7 @@ import math from common import get_all_backend -def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"): +def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="nearest_neighbor"): if layout == 'NCHW': @@ -40,11 +40,11 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH raise NotImplementedError( 'Layout not supported {} '.format(layout)) - B = topi.nn.upsampling(A, scale, layout=layout, method=method) + B = topi.nn.upsampling(A, scale, layout=layout, method=method, align_corners=False) - if method == "BILINEAR": + if method == "bilinear": out_size = (in_height*scale, in_width*scale) - b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout) + b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False) else: b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout) @@ -67,21 +67,21 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH check_device(device) def test_upsampling(): - # NEAREST_NEIGHBOR - NCHW + # nearest_neighbor - NCHW verify_upsampling(8, 16, 32, 32, 2) verify_upsampling(2, 32, 64, 64, 3) - # NEAREST_NEIGHBOR - NHWC + ## nearest_neighbor - NHWC verify_upsampling(8, 16, 32, 32, 2, layout="NHWC") verify_upsampling(2, 32, 64, 64, 3, layout="NHWC") - # BILINEAR - NCHW - verify_upsampling(2, 2, 32, 32, 2, method="BILINEAR") - verify_upsampling(2, 2, 32, 32, 3, method="BILINEAR") + # bilinear - NCHW + verify_upsampling(2, 2, 32, 32, 2, method="bilinear") + verify_upsampling(2, 2, 32, 32, 3, method="bilinear") - # BILINEAR - NHWC - verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="BILINEAR") - verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="BILINEAR") + # bilinear - NHWC + verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="bilinear") + verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="bilinear") if __name__ == "__main__": test_upsampling() -- 2.7.4