From e369c5a9cbacb926ca7b95ebc4ae01a6de33c6cd Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sat, 23 May 2020 00:57:58 -0400 Subject: [PATCH] [Relay,Topi][OP] affine_grid and grid_sample (#5657) * [Relay,Topi][OP] affine_grid and grid_sample * lint --- include/tvm/relay/attrs/image.h | 28 +++++ python/tvm/relay/frontend/mxnet.py | 22 ++++ python/tvm/relay/op/image/_image.py | 20 +++ python/tvm/relay/op/image/image.py | 64 ++++++++++ src/relay/op/image/grid_sample.cc | 168 +++++++++++++++++++++++++ tests/python/frontend/mxnet/test_forward.py | 38 ++++++ tests/python/relay/test_op_level5.py | 52 ++++++++ topi/python/topi/image/__init__.py | 1 + topi/python/topi/image/grid_sample.py | 124 ++++++++++++++++++ topi/python/topi/testing/__init__.py | 1 + topi/python/topi/testing/grid_sample_python.py | 65 ++++++++++ topi/tests/python/test_topi_image.py | 83 ++++++++++++ 12 files changed, 666 insertions(+) create mode 100644 src/relay/op/image/grid_sample.cc create mode 100644 topi/python/topi/image/grid_sample.py create mode 100644 topi/python/topi/testing/grid_sample_python.py diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 58fd44b..cf5a6ef 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -167,6 +167,34 @@ struct Dilation2DAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in image affine_grid operator */ +struct AffineGridAttrs : public tvm::AttrsNode { + Array target_shape; + + TVM_DECLARE_ATTRS(AffineGridAttrs, "relay.attrs.AffineGridAttrs") { + TVM_ATTR_FIELD(target_shape).describe("Specifies the output shape (H, W)."); + } +}; + +/*! \brief Attributes used in image grid_sample operator */ +struct GridSampleAttrs : public tvm::AttrsNode { + String method; + String layout; + + TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") { + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "bilinear - Bilinear Interpolation"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_IMAGE_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 9f97ee9..c75612d 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -757,6 +757,26 @@ def _mx_resize(inputs, attrs): return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") +def _mx_grid_generator(inputs, attrs): + transform_type = attrs.get_str("transform_type") + if transform_type == 'affine': + target_shape = attrs.get_int_tuple("target_shape") + return _op.image.affine_grid(_op.reshape(inputs[0], (0, 2, 3)), target_shape) + if transform_type == 'warp': + checked_type = _infer_type(inputs[0]).checked_type + batch, _, height, width = get_const_tuple(checked_type.shape) + dtype = checked_type.dtype + identity_affine = relay.const(np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=dtype)) + identity_affine = _op.broadcast_to(identity_affine, (batch, 2, 3)) + normalizer = (2.0 / np.array([width - 1, height - 1])).reshape(1, -1, 1, 1).astype(dtype) + normalized_flow = inputs[0] * relay.const(normalizer) + grid = _op.image.affine_grid(identity_affine, (height, width)) + return grid + normalized_flow + raise ValueError("unknown transform type" + transform_type) + +def _mx_bilinear_sampler(inputs, attrs): + return _op.image.grid_sample(inputs[0], inputs[1], 'bilinear', 'NCHW') + def _mx_roi_pooling(inputs, attrs): new_attrs = {} new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size") @@ -1996,6 +2016,8 @@ _convert_map = { "_contrib_box_nms" : _mx_box_nms, "_contrib_DeformableConvolution" : _mx_deformable_convolution, "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling, + "GridGenerator" : _mx_grid_generator, + "BilinearSampler" : _mx_bilinear_sampler, # NLP "RNN" : _mx_rnn_layer, "_rnn_param_concat" : _mx_rnn_param_concat, diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 290c0a2..bcb110f 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -19,6 +19,7 @@ from __future__ import absolute_import import topi +from topi.util import get_const_tuple from .. import op as reg from .. import strategy from ..op import OpPattern @@ -67,3 +68,22 @@ reg.register_injective_schedule("image.crop_and_resize") # dilation2d reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy) reg.register_pattern("image.dilation2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# affine_grid +@reg.register_compute("image.affine_grid") +def compute_affine_grid(attrs, inputs, out_dtype): + target_shape = get_const_tuple(attrs.target_shape) + return [topi.image.affine_grid(inputs[0], target_shape)] + +reg.register_injective_schedule("image.affine_grid") + + +# grid_sample +@reg.register_compute("image.grid_sample") +def compute_grid_sample(attrs, inputs, out_dtype): + method = attrs.method + layout = attrs.layout + return [topi.image.grid_sample(inputs[0], inputs[1], method, layout)] + +reg.register_injective_schedule("image.grid_sample") diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 49b35d8..62889e0 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -215,3 +215,67 @@ def dilation2d(data, return _make.dilation2d(data, weight, strides, padding, dilations, data_layout, kernel_layout, out_dtype) + + +def affine_grid(data, target_shape=None): + """affine_grid operator that generates 2D sampling grid. + + This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform + sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine + transformation is then applied on the sampling grid. + + Parameters + ---------- + data : tvm.Tensor + 3-D with shape [batch, 2, 3]. The affine matrix. + + target_shape: list/tuple of two int + Specifies the output shape (H, W). + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, 2, target_height, target_width] + """ + return _make.affine_grid(data, target_shape) + +def grid_sample(data, grid, method='bilinear', layout='NCHW'): + """Applies bilinear sampling to input feature map. + + Given :math:`data` and :math:`grid`, then the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + + :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and + :math:`G()` denotes the interpolation function. + The out-boundary points will be padded with zeros. The shape of the output will be + (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + + grid : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width] + + method : str + The interpolation method. Only 'bilinear' is supported. + + layout : str + The layout of input data and the output. + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width] + """ + return _make.grid_sample(data, grid, method, layout) diff --git a/src/relay/op/image/grid_sample.cc b/src/relay/op/image/grid_sample.cc new file mode 100644 index 0000000..bc69891 --- /dev/null +++ b/src/relay/op/image/grid_sample.cc @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file grid_sample.cc + * \brief affine_grid and grid_sample operator + */ +#include +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relay { + +// relay.image.affine_grid +TVM_REGISTER_NODE_TYPE(AffineGridAttrs); + +bool AffineGridRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + auto batch_size = data->shape[0]; + + const AffineGridAttrs* param = attrs.as(); + CHECK(param != nullptr); + + Array oshape; + + CHECK(data->shape.size() == 3U && reporter->AssertEQ(data->shape[1], 2) && + reporter->AssertEQ(data->shape[2], 3)) + << "data should be an" + "affine matrix with shape [batch_size, 2, 3]"; + CHECK(param->target_shape.defined() && param->target_shape.size() == 2) + << "target_shape should be 2D"; + oshape.push_back(batch_size); + oshape.push_back(2); + oshape.push_back(param->target_shape[0]); + oshape.push_back(param->target_shape[1]); + + // assign output type + reporter->Assign(types[1], TensorType(oshape, data->dtype)); + return true; +} + +// Positional relay function to create affine_grid operator +// used by frontend FFI. +Expr MakeAffineGrid(Expr data, Array target_shape) { + auto attrs = make_object(); + attrs->target_shape = std::move(target_shape); + static const Op& op = Op::Get("image.affine_grid"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.affine_grid").set_body_typed(MakeAffineGrid); + +RELAY_REGISTER_OP("image.affine_grid") + .describe(R"code(affine_grid operator that generates 2D sampling grid. + +This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform +sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine +transformation is then applied on the sampling grid. + +- **data**: data is 3D array of shape [batch, 2, 3], which defines an affine transformation. + +- **out**: out is 4D array of shape [batch, 2, height, width], where each vector + :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)` + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The affine matrix.") + .set_support_level(5) + .add_type_rel("AffineGrid", AffineGridRel) + .set_attr("TOpPattern", kInjective); + +// relay.image.grid_sample +TVM_REGISTER_NODE_TYPE(GridSampleAttrs); + +bool GridSampleRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* grid = types[1].as(); + if (!data || !grid) return false; + const auto* param = attrs.as(); + CHECK(param); + static const Layout kNCHW("NCHW"); + const Layout in_layout(param->layout); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, grid->shape[2]); + oshape.Set(3, grid->shape[3]); + // assign output type + reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + return true; +} + +// Positional relay function to create affine_grid operator +// used by frontend FFI. +Expr MakeGridSample(Expr data, Expr grid, String method, String layout) { + auto attrs = make_object(); + attrs->method = std::move(method); + attrs->layout = std::move(layout); + static const Op& op = Op::Get("image.grid_sample"); + return Call(op, {data, grid}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.grid_sample").set_body_typed(MakeGridSample); + +RELAY_REGISTER_OP("image.grid_sample") + .describe(R"code(Applies grid sampling to input feature map. + +Given :math:`data` and :math:`grid`, then the output is computed by + +.. math:: + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + +:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and +:math:`G()` denotes the interpolation function. +The out-boundary points will be padded with zeros. The shape of the output will be +(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + +The operator assumes that :math:`data` has 'NCHW' layout and :math:`grid` has been normalized to [-1, 1]. + +grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + +- **data**: data is 4D array of shape + (batch_size, channels, in_height, in_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +- **grid**: out is 4D array of shape [batch, 2, out_height, out_width], where each vector + :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)` + +- **out**: out is 4D array of shape + (batch, in_channel, out_height, out_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("GridSample", GridSampleRel) + .set_attr("TOpPattern", kInjective); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 99fc6c3..6d36ea3 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -639,6 +639,42 @@ def test_forward_bilinear_resize(): mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10) verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10)) +def test_forward_grid_generator(): + def verify(shape, transform_type, target_shape): + x = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.GridGenerator(mx.nd.array(x), transform_type, target_shape) + mx_sym = mx.sym.GridGenerator(mx.sym.var("x"), transform_type, target_shape) + shape_dict = {"x": x.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor( + kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + verify((4, 6), 'affine', (16, 32)) + verify((4, 2, 16, 16), 'warp', None) + verify((1, 2, 16, 16), 'warp', None) + +def test_forward_bilinear_sampler(): + def verify(data_shape, grid_shape): + data = np.random.uniform(size=data_shape).astype("float32") + grid = np.random.uniform(low=-1.5, high=1.5, size=grid_shape).astype("float32") + ref_res = mx.nd.BilinearSampler(mx.nd.array(data), mx.nd.array(grid)) + mx_sym = mx.sym.BilinearSampler(mx.sym.var("data"), mx.sym.var("grid")) + shape_dict = {"data": data.shape, "grid": grid.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor( + kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data, grid) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + verify((4, 4, 16, 32), (4, 2, 8, 8)) + verify((4, 4, 16, 32), (4, 2, 32, 32)) + def test_forward_rnn_layer(): def verify(mode, seq_len, input_size, hidden_size, num_layers, batch=1, init_states=True, bidirectional=False): @@ -1211,3 +1247,5 @@ if __name__ == '__main__': test_forward_unravel_index() test_forward_swap_axis() test_forward_correlation() + test_forward_grid_generator() + test_forward_bilinear_sampler() diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index c9d7d42..c306752 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -823,6 +823,56 @@ def test_dilation2d_run(): data_layout='NHWC', kernel_layout='HWI') +def test_affine_grid(): + def verify_affine_grid(num_batch, target_shape): + dtype = 'float32' + data_shape = (num_batch, 2, 3) + data = relay.var("data", relay.ty.TensorType(data_shape, dtype)) + y = relay.image.affine_grid(data, target_shape) + yy = run_infer_type(y) + assert yy.checked_type == relay.ty.TensorType((num_batch, len(target_shape), *target_shape), dtype) + + func = relay.Function([data], y) + data_np = np.random.uniform(size=data_shape).astype(dtype) + ref_res = topi.testing.affine_grid_python(data_np, target_shape) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp1 = relay.create_executor(kind, ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data_np) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + verify_affine_grid(1, (16, 32)) + verify_affine_grid(4, (16, 32)) + + +def test_grid_sample(): + def verify_grid_sample(data_shape, grid_shape): + dtype = 'float32' + batch, channel, _, _ = data_shape + _, _, out_height, out_width = grid_shape + data = relay.var("data", relay.ty.TensorType(data_shape, dtype)) + grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype)) + y = relay.image.grid_sample(data, grid, method='bilinear', layout='NCHW') + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype) + func = relay.Function([data, grid], y) + + data_np = np.random.uniform(size=data_shape).astype(dtype) + grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) + ref_res = topi.testing.grid_sample_nchw_python(data_np, grid_np, method='bilinear') + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp1 = relay.create_executor(kind, ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data_np, grid_np) + tvm.testing.assert_allclose( + op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) + verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32)) + + if __name__ == "__main__": test_resize_infer_type() test_resize() @@ -843,3 +893,5 @@ if __name__ == "__main__": test_space_to_depth() test_dilation2d_infer_type() test_dilation2d_run() + test_affine_grid() + test_grid_sample() diff --git a/topi/python/topi/image/__init__.py b/topi/python/topi/image/__init__.py index 86b9825..914b02e 100644 --- a/topi/python/topi/image/__init__.py +++ b/topi/python/topi/image/__init__.py @@ -21,3 +21,4 @@ from __future__ import absolute_import as _abs from .resize import * from .dilation2d import * +from .grid_sample import * diff --git a/topi/python/topi/image/grid_sample.py b/topi/python/topi/image/grid_sample.py new file mode 100644 index 0000000..32b6112 --- /dev/null +++ b/topi/python/topi/image/grid_sample.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""affine_grid and grid_sample operator""" +from tvm import te, tir + + +def affine_grid(data, target_shape): + """affine_grid operator that generates 2D sampling grid. + + This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform + sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine + transformation is then applied on the sampling grid. + + Parameters + ---------- + data : tvm.Tensor + 3-D with shape [batch, 2, 3]. The affine matrix. + + target_shape: list/tuple of two int + Specifies the output shape (H, W). + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, 2, target_height, target_width] + """ + assert target_shape is not None + assert len(target_shape) == 2 + assert target_shape[0] > 1 and target_shape[1] > 1, \ + "target height/width should be greater than 1" + + dtype = data.dtype + y_step = tir.const((2.0 - 1e-7)/ (target_shape[0] - 1), dtype=dtype) + x_step = tir.const((2.0 - 1e-7)/ (target_shape[1] - 1), dtype=dtype) + start = tir.const(-1.0, dtype=dtype) + + def _compute(n, dim, i, j): + y = start + i * y_step + x = start + j * x_step + return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2] + + oshape = (data.shape[0], len(target_shape), *target_shape) + return te.compute(oshape, _compute, tag='affine_grid') + + +def grid_sample(data, grid, method='bilinear', layout='NCHW'): + """Applies bilinear sampling to input feature map. + + Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + + :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and + :math:`G()` denotes the interpolation method. + The out-boundary points will be padded with zeros. The shape of the output will be + (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + + grid : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width] + + method : str + The interpolation method. Only 'bilinear' is supported. + + layout : str + The layout of input data and the output. + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width] + """ + batch, in_channel, in_height, in_width = data.shape + out_height, out_width = grid.shape[2:] + assert method == 'bilinear', "Only bilinear is supported" + assert layout == "NCHW", "Only NCHW is supported" + + def _get_pixel_value(n, c, h, w): + return te.if_then_else(te.all(h >= 0, w >= 0, h < in_height, w < in_width), + data[n, c, h, w], tir.const(0.0, dtype=data.dtype)) + + def _bilinear_sample(n, c, h, w): + x = grid[n, 0, h, w] + y = grid[n, 1, h, w] + y = (y + 1) * (in_height - 1) / 2 + x = (x + 1) * (in_width - 1) / 2 + x0 = te.floor(x).astype('int32') + y0 = te.floor(y).astype('int32') + x1 = x0 + tir.const(1, 'int32') + y1 = y0 + tir.const(1, 'int32') + return _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) \ + + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0) \ + + _get_pixel_value(n, c, y1, x0) * (y - y0) * (1.0 - (x - x0)) \ + + _get_pixel_value(n, c, y1, x1) * (y - y0) * (x - x0) + + return te.compute((batch, in_channel, out_height, out_width), _bilinear_sample, + tag='grid_sample') diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 511fe16..e677a11 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -57,3 +57,4 @@ from .crop_and_resize_python import crop_and_resize_python from .common import get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, \ get_elemwise_schedule, get_conv2d_nchw_implement, dispatch from .adaptive_pool_python import adaptive_pool +from .grid_sample_python import affine_grid_python, grid_sample_nchw_python diff --git a/topi/python/topi/testing/grid_sample_python.py b/topi/python/topi/testing/grid_sample_python.py new file mode 100644 index 0000000..964d8a2 --- /dev/null +++ b/topi/python/topi/testing/grid_sample_python.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""affine_grid and grid_sample operators in python""" +import math +import numpy as np + + +def affine_grid_python(data, target_shape): + yv, xv = np.meshgrid( + np.arange(target_shape[0]), np.arange(target_shape[1])) + yv = yv.T * 2 / (target_shape[0] - 1) - 1 + xv = xv.T * 2 / (target_shape[1] - 1) - 1 + ones = np.ones_like(xv) + grid = np.stack([xv, yv, ones]).reshape(3, -1) + return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape) + + +def _bilinear_sample_nchw_python(data, grid): + batch, in_channel, in_height, in_width = data.shape + _, _, out_height, out_width = grid.shape + out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype) + + def _within_bound(y, x): + return 0 <= y < in_height and 0 <= x < in_width + + for n in range(0, batch): + for h in range(0, out_height): + for w in range(0, out_width): + x, y = grid[n, :, h, w] + y = (y + 1) * (in_height - 1) / 2 + x = (x + 1) * (in_width - 1) / 2 + y0 = int(math.floor(y)) + x0 = int(math.floor(x)) + y1 = y0 + 1 + x1 = x0 + 1 + if _within_bound(y0, x0): + out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0)) + if _within_bound(y0, x1): + out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0) + if _within_bound(y1, x0): + out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0)) + if _within_bound(y1, x1): + out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0) + return out + + +def grid_sample_nchw_python(data, grid, method='bilinear'): + if method == 'bilinear': + return _bilinear_sample_nchw_python(data, grid) + raise ValueError("invalid method") diff --git a/topi/tests/python/test_topi_image.py b/topi/tests/python/test_topi_image.py index 4eea75d..012ed42 100644 --- a/topi/tests/python/test_topi_image.py +++ b/topi/tests/python/test_topi_image.py @@ -20,6 +20,7 @@ import tvm from tvm import te import topi import topi.testing +from tvm.contrib.pickle_memoize import memoize from common import get_all_backend @@ -204,7 +205,89 @@ def test_crop_and_resize(): size_1, method='nearest_neighbor') verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW") + +def test_affine_grid(): + def verify_affine_grid(num_batch, target_shape): + dtype = "float32" + data_shape = (num_batch, 2, 3) + data = te.placeholder(data_shape, dtype=dtype) + out = topi.image.affine_grid(data, target_shape) + + @memoize("topi.tests.test_affine_grid.verify_affine_grid") + def get_ref_data(): + data_np = np.random.uniform(size=data_shape).astype(dtype) + out_np = topi.testing.affine_grid_python(data_np, target_shape) + return data_np, out_np + + data_np, out_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(out) + tvm_data = tvm.nd.array(data_np, ctx) + tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx) + f = tvm.build(s, [data, out], device) + f(tvm_data, tvm_out) + + tvm.testing.assert_allclose( + tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5) + + for device in get_all_backend(): + check_device(device) + + verify_affine_grid(1, (16, 32)) + verify_affine_grid(4, (16, 32)) + + +def test_grid_sample(): + def verify_grid_sample(data_shape, grid_shape): + dtype = "float32" + data = te.placeholder(data_shape, dtype=dtype) + grid = te.placeholder(grid_shape, dtype=dtype) + out = topi.image.grid_sample(data, grid, 'bilinear') + + @memoize("topi.tests.test_grid_sample.verify_grid_sample") + def get_ref_data(): + data_np = np.random.uniform(size=data_shape).astype(dtype) + # allow grid values to be out-of-bound + grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) + out_np = topi.testing.grid_sample_nchw_python(data_np, grid_np, 'bilinear') + return data_np, grid_np, out_np + + data_np, grid_np, out_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(out) + tvm_data = tvm.nd.array(data_np, ctx) + tvm_grid = tvm.nd.array(grid_np, ctx) + tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx) + f = tvm.build(s, [data, grid, out], device) + f(tvm_data, tvm_grid, tvm_out) + + tvm.testing.assert_allclose( + tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5) + + for device in get_all_backend(): + check_device(device) + + verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) + verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32)) + + if __name__ == "__main__": test_resize() test_resize3d() test_crop_and_resize() + test_affine_grid() + test_grid_sample() -- 2.7.4