From 9ad33feef32783f2250fae8473eb1f6ca2957112 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 7 Aug 2020 17:08:52 -0700 Subject: [PATCH] [Relay][Dynamic] Add Dynamic Resize Op (#6198) * WIP * optionally remove output shape inference from topi * fix resize * add resize to dynamic_to_static pass add resize to dynamic_to_static pass * fix clang-format * fix bad rebase * add argument to dynamic resize doc string * fix i386 test * fix lint --- python/tvm/relay/op/dyn/__init__.py | 2 + python/tvm/relay/op/dyn/image/__init__.py | 20 ++++ python/tvm/relay/op/dyn/image/_image.py | 76 +++++++++++++++ python/tvm/relay/op/dyn/image/_make.py | 20 ++++ python/tvm/relay/op/image/image.py | 17 +++- python/tvm/topi/image/resize.py | 18 ++-- src/relay/op/dyn/image/resize.cc | 109 ++++++++++++++++++++++ src/relay/op/image/resize.cc | 1 + src/relay/op/make_op.h | 3 + src/relay/transforms/dynamic_to_static.cc | 16 ++++ tests/python/relay/dyn/test_dynamic_op_level5.py | 69 ++++++++++++++ tests/python/relay/test_pass_dynamic_to_static.py | 86 +++++++++++++---- 12 files changed, 410 insertions(+), 27 deletions(-) create mode 100644 python/tvm/relay/op/dyn/image/__init__.py create mode 100644 python/tvm/relay/op/dyn/image/_image.py create mode 100644 python/tvm/relay/op/dyn/image/_make.py create mode 100644 src/relay/op/dyn/image/resize.cc create mode 100644 tests/python/relay/dyn/test_dynamic_op_level5.py diff --git a/python/tvm/relay/op/dyn/__init__.py b/python/tvm/relay/op/dyn/__init__.py index 967ecbc..c6dbca3 100644 --- a/python/tvm/relay/op/dyn/__init__.py +++ b/python/tvm/relay/op/dyn/__init__.py @@ -20,3 +20,5 @@ from . import _algorithm from . import _transform from . import _tensor + +from .import image diff --git a/python/tvm/relay/op/dyn/image/__init__.py b/python/tvm/relay/op/dyn/image/__init__.py new file mode 100644 index 0000000..270421a --- /dev/null +++ b/python/tvm/relay/op/dyn/image/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay namespace containing dynamic image ops.""" + +from . import _image diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py new file mode 100644 index 0000000..fa528e9 --- /dev/null +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=invalid-name, unused-argument +"""Backend compiler related feature registration""" +from __future__ import absolute_import + +import tvm.topi +from tvm.runtime import convert +from tvm.te.hybrid import script +from tvm.topi.util import nchw_pack_layout, nchw_xc_layout +from ... import op as reg + + +# resize +@reg.register_compute("dyn.image.resize") +def compute_resize(attrs, inputs, out_type): + layout = attrs.layout + method = attrs.method + coord_trans = attrs.coordinate_transformation_mode + out_dtype = attrs.out_dtype + return [ + tvm.topi.image.resize(inputs[0], inputs[1], layout, method, coord_trans, out_dtype, + out_type.shape) + ] + + +reg.register_injective_schedule("dyn.image.resize") + + +@script +def _NCHW_resize_shape_func(dshape, size, ndim): + out = output_tensor((ndim, ), "int64") + for i in const_range(ndim): + out[i] = int64(dshape[i]) + out[2] = int64(size[0]) + out[3] = int64(size[1]) + return out + + +@script +def _NHWC_resize_shape_func(dshape, size, ndim): + out = output_tensor((ndim, ), "int64") + for i in const_range(ndim): + out[i] = int64(dshape[i]) + out[1] = int64(size[0]) + out[2] = int64(size[1]) + return out + + +@reg.register_shape_func("dyn.image.resize", True) +def resize_shape_func(attrs, inputs, _): + """ + Shape function for dyn.image.resize op. + """ + layout = attrs.layout + if layout == 'NHWC': + out = [_NHWC_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))] + elif (layout == 'NCHW') or nchw_pack_layout(layout) or nchw_xc_layout(layout): + out = [_NCHW_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))] + else: + raise ValueError("Resize Unsupported Layout", layout) + return out diff --git a/python/tvm/relay/op/dyn/image/_make.py b/python/tvm/relay/op/dyn/image/_make.py new file mode 100644 index 0000000..69830ae --- /dev/null +++ b/python/tvm/relay/op/dyn/image/_make.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relay.op.dyn.image._make", __name__) diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 99a6a0f..607e1d3 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -16,6 +16,9 @@ # under the License. """Image operations.""" from . import _make +from ..dyn.image import _make as _dyn_make +from ...expr import Expr + def resize(data, size, @@ -38,7 +41,7 @@ def resize(data, data : relay.Expr The input data to the operator. - size: Tuple of Expr + size: Tuple of Int or Expr The out size to which the image will be resized. layout : str, optional @@ -61,6 +64,9 @@ def resize(data, result: relay.Expr The resized result. """ + if isinstance(size, Expr): + return _dyn_make.resize(data, size, layout, method, coordinate_transformation_mode, + out_dtype) return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) @@ -156,8 +162,8 @@ def crop_and_resize(data, result: relay.Expr The computed result. """ - return _make.crop_and_resize(data, boxes, box_indices, crop_size, - layout, method, extrapolation_value, out_dtype) + return _make.crop_and_resize(data, boxes, box_indices, crop_size, layout, method, + extrapolation_value, out_dtype) def dilation2d(data, @@ -213,8 +219,8 @@ def dilation2d(data, The computed result. """ - return _make.dilation2d(data, weight, strides, padding, dilations, data_layout, - kernel_layout, out_dtype) + return _make.dilation2d(data, weight, strides, padding, dilations, data_layout, kernel_layout, + out_dtype) def affine_grid(data, target_shape=None): @@ -239,6 +245,7 @@ def affine_grid(data, target_shape=None): """ return _make.affine_grid(data, target_shape) + def grid_sample(data, grid, method='bilinear', layout='NCHW'): """Applies bilinear sampling to input feature map. diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index d6c0845..aab3009 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -491,7 +491,7 @@ def resize_bicubic(indices, data, image_height, image_width, def resize(data, size, layout="NCHW", method="bilinear", - coordinate_transformation_mode="half_pixel", out_dtype=None): + coordinate_transformation_mode="half_pixel", out_dtype=None, output_shape=None): """Perform resize operation on the data. Parameters @@ -519,6 +519,9 @@ def resize(data, size, layout="NCHW", method="bilinear", out_dtype: string, optional Type to return. If left None will be same as input type. + output_shape: optional + Shape to return. If left None will be inferred + Returns ------- output : tvm.te.Tensor @@ -528,19 +531,22 @@ def resize(data, size, layout="NCHW", method="bilinear", """ method = method.lower() - if layout == 'NHWC': in_n, in_h, in_w, in_c = data.shape - output_shape = [in_n, size[0], size[1], in_c] + if output_shape is None: + output_shape = [in_n, size[0], size[1], in_c] elif layout == 'NCHW': in_n, in_c, in_h, in_w = data.shape - output_shape = [in_n, in_c, size[0], size[1]] + if output_shape is None: + output_shape = [in_n, in_c, size[0], size[1]] elif nchw_pack_layout(layout):# for NCHWinic in_n, in_c, in_h, in_w, in_inum, in_ic = data.shape - output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic] + if output_shape is None: + output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic] elif nchw_xc_layout(layout):# for NCHWxc in_n, in_c, in_h, in_w, in_cc = data.shape - output_shape = [in_n, in_c, size[0], size[1], in_cc] + if output_shape is None: + output_shape = [in_n, in_c, size[0], size[1], in_cc] else: raise ValueError('%s layout is not supported.' % layout) diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc new file mode 100644 index 0000000..23e1740 --- /dev/null +++ b/src/relay/op/dyn/image/resize.cc @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file resize.cc + * \brief Image resize operators + */ +#include +#include +#include + +#include "../../op_common.h" + +namespace tvm { +namespace relay { +namespace dyn { + +TVM_REGISTER_NODE_TYPE(ResizeAttrs); + +bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // {data, size, out} + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCHW("NCHW"); + + const ResizeAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->layout); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); + CHECK(layout_converter.defined()) + << "Resize only support input layouts that are convertible from NCHW." + << " But got " << in_layout; + + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, Any()); + oshape.Set(3, Any()); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + // assign output type + reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); + return true; +} + +// Positional relay function to create image operator +// used by frontend FFI. +Expr MakeResize(Expr data, Expr size, String layout, String method, + String coordinate_transformation_mode, DataType out_dtype) { + auto attrs = make_object(); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("dyn.image.resize"); + return Call(op, {data, size}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize").set_body_typed(MakeResize); + +RELAY_REGISTER_OP("dyn.image.resize") + .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. + +- **data**: data is 4D array of shape + (batch_size, channels, in_height, in_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +- **size**: data is 2D array of shape (2,) with values + (new_height, new_width) + +- **out**: Output is 4D array of shape + for layout NCHW + (batch_size, channels, size[0], size[1]) + + for layout NHWC + (batch_size, size[0], size[1], channels) +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("size", "Tensor", "The output size tensor.") + .set_support_level(5) + .add_type_rel("DynResize", ResizeRel) + .set_attr("TOpPattern", kInjective); + +} // namespace dyn +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b6d2c71..41b7afe 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -25,6 +25,7 @@ #include #include +#include "../make_op.h" #include "../op_common.h" namespace tvm { diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index d2c170d..c03a7bf 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -80,6 +80,9 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); +Expr MakeResize(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, DataType out_dtype); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_MAKE_OP_H_ diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 8501ee5..d0a6b07 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -23,6 +23,7 @@ * \brief Rewrite Dynamic Operations to Static operations where possible */ #include +#include #include #include @@ -98,6 +99,21 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, + {Op::Get("dyn.image.resize"), + [](const CallNode* call_node) { + if (const ConstantNode* size = call_node->args[1].as()) { + const ResizeAttrs* param = call_node->attrs.as(); + CHECK(param); + auto size_int = ToVector(size->data); + Array size_prim; + for (size_t i = 0; i < size_int.size(); ++i) { + size_prim.push_back(size_int[i]); + } + return MakeResize(call_node->args[0], size_prim, param->layout, param->method, + param->coordinate_transformation_mode, param->out_dtype); + } + return Expr(nullptr); + }}, }; } diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py new file mode 100644 index 0000000..f0095a1 --- /dev/null +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Support level5 operator test cases. +""" +import math +import numpy as np +import tvm +from tvm import te +from tvm import relay +from tvm.relay import transform +from tvm.relay.testing import ctx_list, run_infer_type +import tvm.topi.testing + + +def test_resize_infer_type(): + n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) + size = relay.var("size", relay.TensorType((2,), "int8")) + z = relay.image.resize(x, size) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") + + +def test_resize(): + def verify_resize(dshape, scale, method, layout): + if layout == "NHWC": + size = (dshape[1] * scale, dshape[2] * scale) + else: + size = (dshape[2] * scale, dshape[3] * scale) + size = np.array(size).astype("int64") + x_data = np.random.uniform(size=dshape).astype("float32") + if method == "bilinear": + ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) + else: + ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + x = relay.var("x", relay.TensorType(dshape, "float32")) + size_var = relay.var("size", relay.TensorType((2,), "int64")) + z = relay.image.resize(x, size_var, layout, method, "align_corners") + zz = run_infer_type(z) + func = relay.Function([x, size_var], z) + + for target, ctx in ctx_list(): + if "llvm" not in target: continue + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data, size) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) + for method in ["bilinear", "nearest_neighbor"]: + for layout in ["NCHW", "NHWC"]: + verify_resize((1, 4, 4, 4), 2, method, layout) + +if __name__ == "__main__": + test_resize_infer_type() + test_resize() diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index a50c9df..5342f2d 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -21,7 +21,6 @@ from tvm import relay from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.testing import run_infer_type, create_workload, ctx_list - import tvm.topi.testing @@ -33,14 +32,16 @@ def run_opt_pass(expr, opt_pass): entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body -def verify_func(func, data, ref_res): + +def verify_func(func, data, ref_res, rtol=1e-5, atol=1e-7): assert isinstance(data, list) for target, ctx in ctx_list(): for kind in ["graph", "vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(*data) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol) + def test_dynamic_to_static_reshape(): def verify_reshape(shape, newshape, oshape): @@ -48,7 +49,8 @@ def test_dynamic_to_static_reshape(): y = relay.var("y", relay.TensorType(newshape, "float32")) z = relay.reshape(x, relay.shape_of(y)) func = run_infer_type(relay.Function([x, y], z)) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Call) @@ -64,6 +66,7 @@ def test_dynamic_to_static_reshape(): verify_reshape((2, 3, 4), (8, 3), (8, 3)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) + def test_dynamic_to_static_double_reshape(): def verify_reshape(shape, newshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -71,7 +74,8 @@ def test_dynamic_to_static_double_reshape(): z = relay.reshape(x, relay.shape_of(y)) z = relay.reshape(z, relay.shape_of(x)) func = run_infer_type(relay.Function([x, y], z)) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Call) @@ -86,6 +90,7 @@ def test_dynamic_to_static_double_reshape(): verify_reshape((2, 3, 4), (8, 3)) verify_reshape((4, 7), (2, 7, 2)) + def test_dynamic_to_static_quad_reshape(): def verify_reshape(shape, newshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -95,7 +100,8 @@ def test_dynamic_to_static_quad_reshape(): z3 = relay.reshape(z2, relay.shape_of(z1)) z4 = relay.reshape(z3, relay.shape_of(z2)) func = run_infer_type(relay.Function([x, y], z4)) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Call) @@ -110,13 +116,15 @@ def test_dynamic_to_static_quad_reshape(): verify_reshape((2, 3, 4), (8, 3)) verify_reshape((4, 7), (2, 7, 2)) + def test_dynamic_to_static_tile(): def verify_tile(shape, reps, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) y = relay.var("y", relay.TensorType(reps, "float32")) z = relay.tile(x, relay.shape_of(y)) func = run_infer_type(relay.Function([x, y], z)) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Call) @@ -131,6 +139,7 @@ def test_dynamic_to_static_tile(): verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20)) verify_tile((4, 7), (4, 2), (16, 14)) + def test_dynamic_to_static_topk(): def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) @@ -159,7 +168,8 @@ def test_dynamic_to_static_topk(): np_values[i, :] = np_data[i, np_indices[i, :]] np_indices = np_indices.astype(dtype) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Call) assert zz.op == relay.op.get("topk") @@ -177,12 +187,15 @@ def test_dynamic_to_static_topk(): tvm.testing.assert_allclose(op_res.asnumpy(), np_values) else: tvm.testing.assert_allclose(op_res.asnumpy(), np_indices) + np.random.seed(0) for k in [0, 1, 5]: for axis in [0, -1, 1]: for ret_type in ["both", "values", "indices"]: verify_topk(k, axis, ret_type, True, "int64") verify_topk(k, axis, ret_type, False, "float32") + + def test_dynamic_to_static_broadcast_to(): def verify_broadcast_to(shape, broadcast_shape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -190,28 +203,32 @@ def test_dynamic_to_static_broadcast_to(): z = relay.broadcast_to(x, shape=relay.shape_of(y)) func = run_infer_type(relay.Function([x, y], z)) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Call) assert zz.op == relay.op.get("broadcast_to") assert zz.checked_type == relay.ty.TensorType(broadcast_shape, "float32") - + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") y_data = np.random.uniform(low=-1, high=1, size=broadcast_shape).astype("float32") - + ref_res = np.broadcast_to(x_data, y_data.shape) verify_func(func2, [x_data, y_data], ref_res) + verify_broadcast_to((3, 1), (3, 3)) - + + def test_dynamic_to_static_zeros_ones(): def verify_ones_zeros(shape, dtype): for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]: x = relay.var("x", relay.TensorType(shape, dtype)) y = op(relay.shape_of(x), dtype) - + func = run_infer_type(relay.Function([x], y)) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Constant) @@ -224,6 +241,39 @@ def test_dynamic_to_static_zeros_ones(): verify_ones_zeros((1, 2, 3), 'int64') verify_ones_zeros((9, 8, 3, 4), 'float32') + +def test_dynamic_to_static_resize(): + def verify_resize(shape, scale, method, layout): + if layout == "NHWC": + size = (shape[1] * scale, shape[2] * scale) + else: + size = (shape[2] * scale, shape[3] * scale) + + x = relay.var("x", relay.TensorType(shape, "float32")) + size_var = relay.const(np.array(size).astype("float32")) + z = relay.image.resize(x, size_var, layout, method, "align_corners") + + func = run_infer_type(relay.Function([x], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("image.resize") + + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + + if method == "bilinear": + ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) + else: + ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + verify_func(func2, [x_data], ref_res, rtol=1e-4, atol=1e-6) + + for method in ["bilinear", "nearest_neighbor"]: + for layout in ["NCHW", "NHWC"]: + verify_resize((1, 4, 4, 4), 2, method, layout) + + def test_dynamic_to_static_one_hot(): def _verify(indices_shape, depth, on_value, off_value, axis, dtype): indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) @@ -233,7 +283,8 @@ def test_dynamic_to_static_one_hot(): out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype) func = relay.Function([indices], out) - func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), + transform.InferType()) zz = func2.body assert isinstance(zz, relay.Call) @@ -250,7 +301,8 @@ def test_dynamic_to_static_one_hot(): _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") -if __name__=="__main__": + +if __name__ == "__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() test_dynamic_to_static_quad_reshape() @@ -258,3 +310,5 @@ if __name__=="__main__": test_dynamic_to_static_topk() test_dynamic_to_static_broadcast_to() test_dynamic_to_static_zeros_ones() + test_dynamic_to_static_resize() + test_dynamic_to_static_one_hot() -- 2.7.4