}
};
+/*! \brief Attributes used in image affine_grid operator */
+struct AffineGridAttrs : public tvm::AttrsNode<AffineGridAttrs> {
+ Array<IndexExpr> target_shape;
+
+ TVM_DECLARE_ATTRS(AffineGridAttrs, "relay.attrs.AffineGridAttrs") {
+ TVM_ATTR_FIELD(target_shape).describe("Specifies the output shape (H, W).");
+ }
+};
+
+/*! \brief Attributes used in image grid_sample operator */
+struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
+ String method;
+ String layout;
+
+ TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") {
+ TVM_ATTR_FIELD(method)
+ .set_default("bilinear")
+ .describe(
+ "Specify the mode to use for scaling."
+ "bilinear - Bilinear Interpolation");
+ TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Resize is applied on the 'H' and"
+ "'W' dimensions.");
+ }
+};
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_IMAGE_H_
return _op.image.resize(inputs[0], size,
coordinate_transformation_mode="align_corners")
+def _mx_grid_generator(inputs, attrs):
+ transform_type = attrs.get_str("transform_type")
+ if transform_type == 'affine':
+ target_shape = attrs.get_int_tuple("target_shape")
+ return _op.image.affine_grid(_op.reshape(inputs[0], (0, 2, 3)), target_shape)
+ if transform_type == 'warp':
+ checked_type = _infer_type(inputs[0]).checked_type
+ batch, _, height, width = get_const_tuple(checked_type.shape)
+ dtype = checked_type.dtype
+ identity_affine = relay.const(np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=dtype))
+ identity_affine = _op.broadcast_to(identity_affine, (batch, 2, 3))
+ normalizer = (2.0 / np.array([width - 1, height - 1])).reshape(1, -1, 1, 1).astype(dtype)
+ normalized_flow = inputs[0] * relay.const(normalizer)
+ grid = _op.image.affine_grid(identity_affine, (height, width))
+ return grid + normalized_flow
+ raise ValueError("unknown transform type" + transform_type)
+
+def _mx_bilinear_sampler(inputs, attrs):
+ return _op.image.grid_sample(inputs[0], inputs[1], 'bilinear', 'NCHW')
+
def _mx_roi_pooling(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
"_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
+ "GridGenerator" : _mx_grid_generator,
+ "BilinearSampler" : _mx_bilinear_sampler,
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
from __future__ import absolute_import
import topi
+from topi.util import get_const_tuple
from .. import op as reg
from .. import strategy
from ..op import OpPattern
# dilation2d
reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy)
reg.register_pattern("image.dilation2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+# affine_grid
+@reg.register_compute("image.affine_grid")
+def compute_affine_grid(attrs, inputs, out_dtype):
+ target_shape = get_const_tuple(attrs.target_shape)
+ return [topi.image.affine_grid(inputs[0], target_shape)]
+
+reg.register_injective_schedule("image.affine_grid")
+
+
+# grid_sample
+@reg.register_compute("image.grid_sample")
+def compute_grid_sample(attrs, inputs, out_dtype):
+ method = attrs.method
+ layout = attrs.layout
+ return [topi.image.grid_sample(inputs[0], inputs[1], method, layout)]
+
+reg.register_injective_schedule("image.grid_sample")
return _make.dilation2d(data, weight, strides, padding, dilations, data_layout,
kernel_layout, out_dtype)
+
+
+def affine_grid(data, target_shape=None):
+ """affine_grid operator that generates 2D sampling grid.
+
+ This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
+ sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
+ transformation is then applied on the sampling grid.
+
+ Parameters
+ ----------
+ data : tvm.Tensor
+ 3-D with shape [batch, 2, 3]. The affine matrix.
+
+ target_shape: list/tuple of two int
+ Specifies the output shape (H, W).
+
+ Returns
+ -------
+ Output : tvm.Tensor
+ 4-D with shape [batch, 2, target_height, target_width]
+ """
+ return _make.affine_grid(data, target_shape)
+
+def grid_sample(data, grid, method='bilinear', layout='NCHW'):
+ """Applies bilinear sampling to input feature map.
+
+ Given :math:`data` and :math:`grid`, then the output is computed by
+
+ .. math::
+
+ x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
+ y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
+ output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src})
+
+ :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
+ :math:`G()` denotes the interpolation function.
+ The out-boundary points will be padded with zeros. The shape of the output will be
+ (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
+
+ The operator assumes that :math:`grid` has been normalized to [-1, 1].
+
+ grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample.
+
+ Parameters
+ ----------
+ data : tvm.Tensor
+ 4-D with shape [batch, in_channel, in_height, in_width]
+
+ grid : tvm.Tensor
+ 4-D with shape [batch, 2, out_height, out_width]
+
+ method : str
+ The interpolation method. Only 'bilinear' is supported.
+
+ layout : str
+ The layout of input data and the output.
+
+ Returns
+ -------
+ Output : tvm.Tensor
+ 4-D with shape [batch, 2, out_height, out_width]
+ """
+ return _make.grid_sample(data, grid, method, layout)
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file grid_sample.cc
+ * \brief affine_grid and grid_sample operator
+ */
+#include <tvm/relay/attrs/image.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.image.affine_grid
+TVM_REGISTER_NODE_TYPE(AffineGridAttrs);
+
+bool AffineGridRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) return false;
+ auto batch_size = data->shape[0];
+
+ const AffineGridAttrs* param = attrs.as<AffineGridAttrs>();
+ CHECK(param != nullptr);
+
+ Array<IndexExpr> oshape;
+
+ CHECK(data->shape.size() == 3U && reporter->AssertEQ(data->shape[1], 2) &&
+ reporter->AssertEQ(data->shape[2], 3))
+ << "data should be an"
+ "affine matrix with shape [batch_size, 2, 3]";
+ CHECK(param->target_shape.defined() && param->target_shape.size() == 2)
+ << "target_shape should be 2D";
+ oshape.push_back(batch_size);
+ oshape.push_back(2);
+ oshape.push_back(param->target_shape[0]);
+ oshape.push_back(param->target_shape[1]);
+
+ // assign output type
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
+ return true;
+}
+
+// Positional relay function to create affine_grid operator
+// used by frontend FFI.
+Expr MakeAffineGrid(Expr data, Array<IndexExpr> target_shape) {
+ auto attrs = make_object<AffineGridAttrs>();
+ attrs->target_shape = std::move(target_shape);
+ static const Op& op = Op::Get("image.affine_grid");
+ return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.image._make.affine_grid").set_body_typed(MakeAffineGrid);
+
+RELAY_REGISTER_OP("image.affine_grid")
+ .describe(R"code(affine_grid operator that generates 2D sampling grid.
+
+This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
+sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
+transformation is then applied on the sampling grid.
+
+- **data**: data is 3D array of shape [batch, 2, 3], which defines an affine transformation.
+
+- **out**: out is 4D array of shape [batch, 2, height, width], where each vector
+ :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`
+
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<AffineGridAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The affine matrix.")
+ .set_support_level(5)
+ .add_type_rel("AffineGrid", AffineGridRel)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+// relay.image.grid_sample
+TVM_REGISTER_NODE_TYPE(GridSampleAttrs);
+
+bool GridSampleRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 3);
+ const auto* data = types[0].as<TensorTypeNode>();
+ const auto* grid = types[1].as<TensorTypeNode>();
+ if (!data || !grid) return false;
+ const auto* param = attrs.as<GridSampleAttrs>();
+ CHECK(param);
+ static const Layout kNCHW("NCHW");
+ const Layout in_layout(param->layout);
+ auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
+ auto oshape = layout_converter.ForwardShape(data->shape);
+ oshape.Set(2, grid->shape[2]);
+ oshape.Set(3, grid->shape[3]);
+ // assign output type
+ reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
+ return true;
+}
+
+// Positional relay function to create affine_grid operator
+// used by frontend FFI.
+Expr MakeGridSample(Expr data, Expr grid, String method, String layout) {
+ auto attrs = make_object<GridSampleAttrs>();
+ attrs->method = std::move(method);
+ attrs->layout = std::move(layout);
+ static const Op& op = Op::Get("image.grid_sample");
+ return Call(op, {data, grid}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.image._make.grid_sample").set_body_typed(MakeGridSample);
+
+RELAY_REGISTER_OP("image.grid_sample")
+ .describe(R"code(Applies grid sampling to input feature map.
+
+Given :math:`data` and :math:`grid`, then the output is computed by
+
+.. math::
+ x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
+ y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
+ output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src})
+
+:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
+:math:`G()` denotes the interpolation function.
+The out-boundary points will be padded with zeros. The shape of the output will be
+(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
+
+The operator assumes that :math:`data` has 'NCHW' layout and :math:`grid` has been normalized to [-1, 1].
+
+grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample.
+
+- **data**: data is 4D array of shape
+ (batch_size, channels, in_height, in_width) for NCHW
+ (batch_size, in_height, in_width, channels) for NHWC
+
+- **grid**: out is 4D array of shape [batch, 2, out_height, out_width], where each vector
+ :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`
+
+- **out**: out is 4D array of shape
+ (batch, in_channel, out_height, out_width) for NCHW
+ (batch_size, in_height, in_width, channels) for NHWC
+
+)code" TVM_ADD_FILELINE)
+ .set_num_inputs(2)
+ .set_attrs_type<GridSampleAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(5)
+ .add_type_rel("GridSample", GridSampleRel)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+} // namespace relay
+} // namespace tvm
mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
+def test_forward_grid_generator():
+ def verify(shape, transform_type, target_shape):
+ x = np.random.uniform(size=shape).astype("float32")
+ ref_res = mx.nd.GridGenerator(mx.nd.array(x), transform_type, target_shape)
+ mx_sym = mx.sym.GridGenerator(mx.sym.var("x"), transform_type, target_shape)
+ shape_dict = {"x": x.shape}
+ mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(
+ kind, mod=mod, ctx=ctx, target=target)
+ op_res = intrp.evaluate()(x)
+ tvm.testing.assert_allclose(
+ op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+ verify((4, 6), 'affine', (16, 32))
+ verify((4, 2, 16, 16), 'warp', None)
+ verify((1, 2, 16, 16), 'warp', None)
+
+def test_forward_bilinear_sampler():
+ def verify(data_shape, grid_shape):
+ data = np.random.uniform(size=data_shape).astype("float32")
+ grid = np.random.uniform(low=-1.5, high=1.5, size=grid_shape).astype("float32")
+ ref_res = mx.nd.BilinearSampler(mx.nd.array(data), mx.nd.array(grid))
+ mx_sym = mx.sym.BilinearSampler(mx.sym.var("data"), mx.sym.var("grid"))
+ shape_dict = {"data": data.shape, "grid": grid.shape}
+ mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(
+ kind, mod=mod, ctx=ctx, target=target)
+ op_res = intrp.evaluate()(data, grid)
+ tvm.testing.assert_allclose(
+ op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+ verify((4, 4, 16, 32), (4, 2, 8, 8))
+ verify((4, 4, 16, 32), (4, 2, 32, 32))
+
def test_forward_rnn_layer():
def verify(mode, seq_len, input_size, hidden_size, num_layers,
batch=1, init_states=True, bidirectional=False):
test_forward_unravel_index()
test_forward_swap_axis()
test_forward_correlation()
+ test_forward_grid_generator()
+ test_forward_bilinear_sampler()
data_layout='NHWC', kernel_layout='HWI')
+def test_affine_grid():
+ def verify_affine_grid(num_batch, target_shape):
+ dtype = 'float32'
+ data_shape = (num_batch, 2, 3)
+ data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
+ y = relay.image.affine_grid(data, target_shape)
+ yy = run_infer_type(y)
+ assert yy.checked_type == relay.ty.TensorType((num_batch, len(target_shape), *target_shape), dtype)
+
+ func = relay.Function([data], y)
+ data_np = np.random.uniform(size=data_shape).astype(dtype)
+ ref_res = topi.testing.affine_grid_python(data_np, target_shape)
+
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp1 = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res1 = intrp1.evaluate(func)(data_np)
+ tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+ verify_affine_grid(1, (16, 32))
+ verify_affine_grid(4, (16, 32))
+
+
+def test_grid_sample():
+ def verify_grid_sample(data_shape, grid_shape):
+ dtype = 'float32'
+ batch, channel, _, _ = data_shape
+ _, _, out_height, out_width = grid_shape
+ data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
+ grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype))
+ y = relay.image.grid_sample(data, grid, method='bilinear', layout='NCHW')
+ yy = run_infer_type(y)
+ assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype)
+ func = relay.Function([data, grid], y)
+
+ data_np = np.random.uniform(size=data_shape).astype(dtype)
+ grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
+ ref_res = topi.testing.grid_sample_nchw_python(data_np, grid_np, method='bilinear')
+
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp1 = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res1 = intrp1.evaluate(func)(data_np, grid_np)
+ tvm.testing.assert_allclose(
+ op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+ verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
+ verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
+
+
if __name__ == "__main__":
test_resize_infer_type()
test_resize()
test_space_to_depth()
test_dilation2d_infer_type()
test_dilation2d_run()
+ test_affine_grid()
+ test_grid_sample()
from .resize import *
from .dilation2d import *
+from .grid_sample import *
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""affine_grid and grid_sample operator"""
+from tvm import te, tir
+
+
+def affine_grid(data, target_shape):
+ """affine_grid operator that generates 2D sampling grid.
+
+ This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
+ sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
+ transformation is then applied on the sampling grid.
+
+ Parameters
+ ----------
+ data : tvm.Tensor
+ 3-D with shape [batch, 2, 3]. The affine matrix.
+
+ target_shape: list/tuple of two int
+ Specifies the output shape (H, W).
+
+ Returns
+ -------
+ Output : tvm.Tensor
+ 4-D with shape [batch, 2, target_height, target_width]
+ """
+ assert target_shape is not None
+ assert len(target_shape) == 2
+ assert target_shape[0] > 1 and target_shape[1] > 1, \
+ "target height/width should be greater than 1"
+
+ dtype = data.dtype
+ y_step = tir.const((2.0 - 1e-7)/ (target_shape[0] - 1), dtype=dtype)
+ x_step = tir.const((2.0 - 1e-7)/ (target_shape[1] - 1), dtype=dtype)
+ start = tir.const(-1.0, dtype=dtype)
+
+ def _compute(n, dim, i, j):
+ y = start + i * y_step
+ x = start + j * x_step
+ return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
+
+ oshape = (data.shape[0], len(target_shape), *target_shape)
+ return te.compute(oshape, _compute, tag='affine_grid')
+
+
+def grid_sample(data, grid, method='bilinear', layout='NCHW'):
+ """Applies bilinear sampling to input feature map.
+
+ Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by
+
+ .. math::
+
+ x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
+ y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
+ output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src})
+
+ :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
+ :math:`G()` denotes the interpolation method.
+ The out-boundary points will be padded with zeros. The shape of the output will be
+ (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
+
+ The operator assumes that :math:`grid` has been normalized to [-1, 1].
+
+ grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample.
+
+ Parameters
+ ----------
+ data : tvm.Tensor
+ 4-D with shape [batch, in_channel, in_height, in_width]
+
+ grid : tvm.Tensor
+ 4-D with shape [batch, 2, out_height, out_width]
+
+ method : str
+ The interpolation method. Only 'bilinear' is supported.
+
+ layout : str
+ The layout of input data and the output.
+
+ Returns
+ -------
+ Output : tvm.Tensor
+ 4-D with shape [batch, 2, out_height, out_width]
+ """
+ batch, in_channel, in_height, in_width = data.shape
+ out_height, out_width = grid.shape[2:]
+ assert method == 'bilinear', "Only bilinear is supported"
+ assert layout == "NCHW", "Only NCHW is supported"
+
+ def _get_pixel_value(n, c, h, w):
+ return te.if_then_else(te.all(h >= 0, w >= 0, h < in_height, w < in_width),
+ data[n, c, h, w], tir.const(0.0, dtype=data.dtype))
+
+ def _bilinear_sample(n, c, h, w):
+ x = grid[n, 0, h, w]
+ y = grid[n, 1, h, w]
+ y = (y + 1) * (in_height - 1) / 2
+ x = (x + 1) * (in_width - 1) / 2
+ x0 = te.floor(x).astype('int32')
+ y0 = te.floor(y).astype('int32')
+ x1 = x0 + tir.const(1, 'int32')
+ y1 = y0 + tir.const(1, 'int32')
+ return _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) \
+ + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0) \
+ + _get_pixel_value(n, c, y1, x0) * (y - y0) * (1.0 - (x - x0)) \
+ + _get_pixel_value(n, c, y1, x1) * (y - y0) * (x - x0)
+
+ return te.compute((batch, in_channel, out_height, out_width), _bilinear_sample,
+ tag='grid_sample')
from .common import get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, \
get_elemwise_schedule, get_conv2d_nchw_implement, dispatch
from .adaptive_pool_python import adaptive_pool
+from .grid_sample_python import affine_grid_python, grid_sample_nchw_python
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""affine_grid and grid_sample operators in python"""
+import math
+import numpy as np
+
+
+def affine_grid_python(data, target_shape):
+ yv, xv = np.meshgrid(
+ np.arange(target_shape[0]), np.arange(target_shape[1]))
+ yv = yv.T * 2 / (target_shape[0] - 1) - 1
+ xv = xv.T * 2 / (target_shape[1] - 1) - 1
+ ones = np.ones_like(xv)
+ grid = np.stack([xv, yv, ones]).reshape(3, -1)
+ return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape)
+
+
+def _bilinear_sample_nchw_python(data, grid):
+ batch, in_channel, in_height, in_width = data.shape
+ _, _, out_height, out_width = grid.shape
+ out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype)
+
+ def _within_bound(y, x):
+ return 0 <= y < in_height and 0 <= x < in_width
+
+ for n in range(0, batch):
+ for h in range(0, out_height):
+ for w in range(0, out_width):
+ x, y = grid[n, :, h, w]
+ y = (y + 1) * (in_height - 1) / 2
+ x = (x + 1) * (in_width - 1) / 2
+ y0 = int(math.floor(y))
+ x0 = int(math.floor(x))
+ y1 = y0 + 1
+ x1 = x0 + 1
+ if _within_bound(y0, x0):
+ out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
+ if _within_bound(y0, x1):
+ out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0)
+ if _within_bound(y1, x0):
+ out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0))
+ if _within_bound(y1, x1):
+ out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0)
+ return out
+
+
+def grid_sample_nchw_python(data, grid, method='bilinear'):
+ if method == 'bilinear':
+ return _bilinear_sample_nchw_python(data, grid)
+ raise ValueError("invalid method")
from tvm import te
import topi
import topi.testing
+from tvm.contrib.pickle_memoize import memoize
from common import get_all_backend
size_1, method='nearest_neighbor')
verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW")
+
+def test_affine_grid():
+ def verify_affine_grid(num_batch, target_shape):
+ dtype = "float32"
+ data_shape = (num_batch, 2, 3)
+ data = te.placeholder(data_shape, dtype=dtype)
+ out = topi.image.affine_grid(data, target_shape)
+
+ @memoize("topi.tests.test_affine_grid.verify_affine_grid")
+ def get_ref_data():
+ data_np = np.random.uniform(size=data_shape).astype(dtype)
+ out_np = topi.testing.affine_grid_python(data_np, target_shape)
+ return data_np, out_np
+
+ data_np, out_np = get_ref_data()
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.testing.get_injective_schedule(device)(out)
+ tvm_data = tvm.nd.array(data_np, ctx)
+ tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx)
+ f = tvm.build(s, [data, out], device)
+ f(tvm_data, tvm_out)
+
+ tvm.testing.assert_allclose(
+ tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
+
+ for device in get_all_backend():
+ check_device(device)
+
+ verify_affine_grid(1, (16, 32))
+ verify_affine_grid(4, (16, 32))
+
+
+def test_grid_sample():
+ def verify_grid_sample(data_shape, grid_shape):
+ dtype = "float32"
+ data = te.placeholder(data_shape, dtype=dtype)
+ grid = te.placeholder(grid_shape, dtype=dtype)
+ out = topi.image.grid_sample(data, grid, 'bilinear')
+
+ @memoize("topi.tests.test_grid_sample.verify_grid_sample")
+ def get_ref_data():
+ data_np = np.random.uniform(size=data_shape).astype(dtype)
+ # allow grid values to be out-of-bound
+ grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
+ out_np = topi.testing.grid_sample_nchw_python(data_np, grid_np, 'bilinear')
+ return data_np, grid_np, out_np
+
+ data_np, grid_np, out_np = get_ref_data()
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.testing.get_injective_schedule(device)(out)
+ tvm_data = tvm.nd.array(data_np, ctx)
+ tvm_grid = tvm.nd.array(grid_np, ctx)
+ tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx)
+ f = tvm.build(s, [data, grid, out], device)
+ f(tvm_data, tvm_grid, tvm_out)
+
+ tvm.testing.assert_allclose(
+ tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
+
+ for device in get_all_backend():
+ check_device(device)
+
+ verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
+ verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
+
+
if __name__ == "__main__":
test_resize()
test_resize3d()
test_crop_and_resize()
+ test_affine_grid()
+ test_grid_sample()