[Relay][Dynamic] Add Dynamic Resize Op (#6198)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Sat, 8 Aug 2020 00:08:52 +0000 (17:08 -0700)
committerGitHub <noreply@github.com>
Sat, 8 Aug 2020 00:08:52 +0000 (17:08 -0700)
* WIP

* optionally remove output shape inference from topi

* fix resize

* add resize to dynamic_to_static pass

* fix clang-format

* fix bad rebase

* add argument to dynamic resize doc string

* fix i386 test

* fix lint

12 files changed:
python/tvm/relay/op/dyn/__init__.py
python/tvm/relay/op/dyn/image/__init__.py [new file with mode: 0644]
python/tvm/relay/op/dyn/image/_image.py [new file with mode: 0644]
python/tvm/relay/op/dyn/image/_make.py [new file with mode: 0644]
python/tvm/relay/op/image/image.py
python/tvm/topi/image/resize.py
src/relay/op/dyn/image/resize.cc [new file with mode: 0644]
src/relay/op/image/resize.cc
src/relay/op/make_op.h
src/relay/transforms/dynamic_to_static.cc
tests/python/relay/dyn/test_dynamic_op_level5.py [new file with mode: 0644]
tests/python/relay/test_pass_dynamic_to_static.py

index 967ecbc..c6dbca3 100644 (file)
@@ -20,3 +20,5 @@
 from . import _algorithm
 from . import _transform
 from . import _tensor
+
+from . import image
diff --git a/python/tvm/relay/op/dyn/image/__init__.py b/python/tvm/relay/op/dyn/image/__init__.py
new file mode 100644 (file)
index 0000000..270421a
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+"""The Relay namespace containing dynamic image ops."""
+
+from . import _image
diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py
new file mode 100644 (file)
index 0000000..fa528e9
--- /dev/null
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Backend compiler related feature registration"""
+from __future__ import absolute_import
+
+import tvm.topi
+from tvm.runtime import convert
+from tvm.te.hybrid import script
+from tvm.topi.util import nchw_pack_layout, nchw_xc_layout
+from ... import op as reg
+
+
+# resize
+@reg.register_compute("dyn.image.resize")
+def compute_resize(attrs, inputs, out_type):
+    layout = attrs.layout
+    method = attrs.method
+    coord_trans = attrs.coordinate_transformation_mode
+    out_dtype = attrs.out_dtype
+    return [
+        tvm.topi.image.resize(inputs[0], inputs[1], layout, method, coord_trans, out_dtype,
+                              out_type.shape)
+    ]
+
+
+reg.register_injective_schedule("dyn.image.resize")
+
+
+@script
+def _NCHW_resize_shape_func(dshape, size, ndim):
+    out = output_tensor((ndim, ), "int64")
+    for i in const_range(ndim):
+        out[i] = int64(dshape[i])
+    out[2] = int64(size[0])
+    out[3] = int64(size[1])
+    return out
+
+
+@script
+def _NHWC_resize_shape_func(dshape, size, ndim):
+    out = output_tensor((ndim, ), "int64")
+    for i in const_range(ndim):
+        out[i] = int64(dshape[i])
+    out[1] = int64(size[0])
+    out[2] = int64(size[1])
+    return out
+
+
+@reg.register_shape_func("dyn.image.resize", True)
+def resize_shape_func(attrs, inputs, _):
+    """
+    Shape function for dyn.image.resize op.
+    """
+    layout = attrs.layout
+    if layout == 'NHWC':
+        out = [_NHWC_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
+    elif (layout == 'NCHW') or nchw_pack_layout(layout) or nchw_xc_layout(layout):
+        out = [_NCHW_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
+    else:
+        raise ValueError("Resize does not support layout %s" % layout)
+    return out
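
For reference, a minimal sketch of the dynamic path end to end, mirroring the new VM test in tests/python/relay/dyn/test_dynamic_op_level5.py below (assumes an LLVM target is available):

    import numpy as np
    import tvm
    from tvm import relay

    # The output size is a runtime tensor, so the type checker cannot infer
    # H/W statically; the shape function registered above computes them in
    # the VM instead.
    x = relay.var("x", relay.TensorType((1, 3, 8, 8), "float32"))
    size = relay.var("size", relay.TensorType((2,), "int64"))
    z = relay.image.resize(x, size, "NCHW", "bilinear", "align_corners")
    func = relay.Function([x, size], z)

    mod = tvm.ir.IRModule.from_expr(func)
    intrp = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    res = intrp.evaluate()(np.random.uniform(size=(1, 3, 8, 8)).astype("float32"),
                           np.array([16, 16], dtype="int64"))
    print(res.asnumpy().shape)  # (1, 3, 16, 16)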
diff --git a/python/tvm/relay/op/dyn/image/_make.py b/python/tvm/relay/op/dyn/image/_make.py
new file mode 100644 (file)
index 0000000..69830ae
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+import tvm._ffi
+
+tvm._ffi._init_api("relay.op.dyn.image._make", __name__)
index 99a6a0f..607e1d3 100644 (file)
@@ -16,6 +16,9 @@
 # under the License.
 """Image operations."""
 from . import _make
+from ..dyn.image import _make as _dyn_make
+from ...expr import Expr
+
 
 def resize(data,
            size,
@@ -38,7 +41,7 @@ def resize(data,
     data : relay.Expr
         The input data to the operator.
 
-    size: Tuple of Expr
+    size: Tuple of Int or Expr
         The out size to which the image will be resized.
 
     layout : str, optional
@@ -61,6 +64,9 @@ def resize(data,
     result: relay.Expr
         The resized result.
     """
+    if isinstance(size, Expr):
+        return _dyn_make.resize(data, size, layout, method, coordinate_transformation_mode,
+                                out_dtype)
     return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
 
 
@@ -156,8 +162,8 @@ def crop_and_resize(data,
     result: relay.Expr
         The computed result.
     """
-    return _make.crop_and_resize(data, boxes, box_indices, crop_size,
-                                 layout, method, extrapolation_value, out_dtype)
+    return _make.crop_and_resize(data, boxes, box_indices, crop_size, layout, method,
+                                 extrapolation_value, out_dtype)
 
 
 def dilation2d(data,
@@ -213,8 +219,8 @@ def dilation2d(data,
         The computed result.
     """
 
-    return _make.dilation2d(data, weight, strides, padding, dilations, data_layout,
-                            kernel_layout, out_dtype)
+    return _make.dilation2d(data, weight, strides, padding, dilations, data_layout, kernel_layout,
+                            out_dtype)
 
 
 def affine_grid(data, target_shape=None):
@@ -239,6 +245,7 @@ def affine_grid(data, target_shape=None):
     """
     return _make.affine_grid(data, target_shape)
 
+
 def grid_sample(data, grid, method='bilinear', layout='NCHW'):
     """Applies bilinear sampling to input feature map.
 
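
The dispatch added to resize above is worth spelling out; a short sketch, assuming only that relay is imported: a Python tuple of ints keeps the static op, while any relay Expr (variable or constant) selects the new dynamic one:

    from tvm import relay

    x = relay.var("x", relay.TensorType((1, 3, 8, 8), "float32"))

    z_static = relay.image.resize(x, (16, 16))    # lowers to image.resize
    size = relay.var("size", relay.TensorType((2,), "int64"))
    z_dynamic = relay.image.resize(x, size)       # lowers to dyn.image.resize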
index d6c0845..aab3009 100644 (file)
@@ -491,7 +491,7 @@ def resize_bicubic(indices, data, image_height, image_width,
 
 
 def resize(data, size, layout="NCHW", method="bilinear",
-           coordinate_transformation_mode="half_pixel", out_dtype=None):
+           coordinate_transformation_mode="half_pixel", out_dtype=None, output_shape=None):
     """Perform resize operation on the data.
 
     Parameters
@@ -519,6 +519,9 @@ def resize(data, size, layout="NCHW", method="bilinear",
     out_dtype: string, optional
         Type to return. If left None will be same as input type.
 
+    output_shape: optional
+        Shape to return. If left None, it will be inferred.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -528,19 +531,22 @@ def resize(data, size, layout="NCHW", method="bilinear",
     """
 
     method = method.lower()
-
     if layout == 'NHWC':
         in_n, in_h, in_w, in_c = data.shape
-        output_shape = [in_n, size[0], size[1], in_c]
+        if output_shape is None:
+            output_shape = [in_n, size[0], size[1], in_c]
     elif layout == 'NCHW':
         in_n, in_c, in_h, in_w = data.shape
-        output_shape = [in_n, in_c, size[0], size[1]]
+        if output_shape is None:
+            output_shape = [in_n, in_c, size[0], size[1]]
     elif nchw_pack_layout(layout):# for NCHWinic
         in_n, in_c, in_h, in_w, in_inum, in_ic = data.shape
-        output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic]
+        if output_shape is None:
+            output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic]
     elif nchw_xc_layout(layout):# for NCHWxc
         in_n, in_c, in_h, in_w, in_cc = data.shape
-        output_shape = [in_n, in_c, size[0], size[1], in_cc]
+        if output_shape is None:
+            output_shape = [in_n, in_c, size[0], size[1], in_cc]
     else:
         raise ValueError('%s layout is not supported.' % layout)
 
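
A hedged sketch of the new output_shape parameter (illustrative, not exercised directly by this patch's tests): when size is a te.Tensor rather than Python integers, the old inference path cannot build the output shape, so the caller, here the dyn.image.resize compute passing out_type.shape, supplies it explicitly:

    import tvm
    from tvm import te

    data = te.placeholder((1, 3, 8, 8), "float32", name="data")
    size = te.placeholder((2,), "int64", name="size")
    # size is a tensor, so shape inference from size[0]/size[1] is skipped;
    # the explicit output_shape is used for the output buffer instead.
    out = tvm.topi.image.resize(data, size, layout="NCHW", method="bilinear",
                                output_shape=(1, 3, 16, 16))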
diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc
new file mode 100644 (file)
index 0000000..23e1740
--- /dev/null
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file resize.cc
+ * \brief Image resize operators
+ */
+#include <tvm/relay/attrs/image.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../../op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+TVM_REGISTER_NODE_TYPE(ResizeAttrs);
+
+bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // {data, size, out}
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCHW("NCHW");
+
+  const ResizeAttrs* param = attrs.as<ResizeAttrs>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->layout);
+  auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
+  CHECK(layout_converter.defined())
+      << "Resize only supports input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+
+  auto oshape = layout_converter.ForwardShape(data->shape);
+  oshape.Set(2, Any());
+  oshape.Set(3, Any());
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+
+  // assign output type
+  reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
+  return true;
+}
+
+// Positional relay function to create image operator
+// used by frontend FFI.
+Expr MakeResize(Expr data, Expr size, String layout, String method,
+                String coordinate_transformation_mode, DataType out_dtype) {
+  auto attrs = make_object<ResizeAttrs>();
+  attrs->layout = std::move(layout);
+  attrs->method = std::move(method);
+  attrs->coordinate_transformation_mode = coordinate_transformation_mode;
+  attrs->out_dtype = out_dtype;
+  static const Op& op = Op::Get("dyn.image.resize");
+  return Call(op, {data, size}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize").set_body_typed(MakeResize);
+
+RELAY_REGISTER_OP("dyn.image.resize")
+    .describe(R"code(Perform resize on the input array with nearest neighbour or bilinear interpolation.
+
+- **data**: data is a 4D array of shape
+            (batch_size, channels, in_height, in_width) for NCHW
+            (batch_size, in_height, in_width, channels) for NHWC
+
+- **size**: size is a 1D array of shape (2,) with values
+            (new_height, new_width)
+
+- **out**: Output is a 4D array of shape
+           for layout NCHW
+           (batch_size, channels, size[0], size[1])
+
+           for layout NHWC
+           (batch_size, size[0], size[1], channels)
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<ResizeAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("size", "Tensor", "The output size tensor.")
+    .set_support_level(5)
+    .add_type_rel("DynResize", ResizeRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace dyn
+}  // namespace relay
+}  // namespace tvm
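
The relation above pins the spatial dimensions to Any; a minimal sketch of the resulting checked type, mirroring test_resize_infer_type in the new test file below:

    from tvm import relay
    from tvm.relay.testing import run_infer_type

    x = relay.var("x", relay.TensorType((1, 3, 8, 8), "float32"))
    size = relay.var("size", relay.TensorType((2,), "int64"))
    zz = run_infer_type(relay.image.resize(x, size))
    assert zz.checked_type == relay.TensorType((1, 3, relay.Any(), relay.Any()),
                                               "float32")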
index b6d2c71..41b7afe 100644 (file)
@@ -25,6 +25,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/tir/data_layout.h>
 
+#include "../make_op.h"
 #include "../op_common.h"
 
 namespace tvm {
index d2c170d..c03a7bf 100644 (file)
@@ -80,6 +80,9 @@ Expr MakeZeros(Array<Integer> shape, DataType dtype);
 
 Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype);
 
+Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
+                String coordinate_transformation_mode, DataType out_dtype);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_MAKE_OP_H_
index 8501ee5..d0a6b07 100644 (file)
@@ -23,6 +23,7 @@
  * \brief Rewrite Dynamic Operations to Static operations where possible
  */
 #include <tvm/relay/attrs/algorithm.h>
+#include <tvm/relay/attrs/image.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 
@@ -98,6 +99,21 @@ class DynamicToStaticMutator : public MixedModeMutator {
            }
            return Expr(nullptr);
          }},
+        {Op::Get("dyn.image.resize"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* size = call_node->args[1].as<ConstantNode>()) {
+             const ResizeAttrs* param = call_node->attrs.as<ResizeAttrs>();
+             CHECK(param);
+             auto size_int = ToVector(size->data);
+             Array<PrimExpr> size_prim;
+             for (size_t i = 0; i < size_int.size(); ++i) {
+               size_prim.push_back(size_int[i]);
+             }
+             return MakeResize(call_node->args[0], size_prim, param->layout, param->method,
+                               param->coordinate_transformation_mode, param->out_dtype);
+           }
+           return Expr(nullptr);
+         }},
     };
   }
 
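A minimal sketch of the rewrite this handler enables, assuming an explicit InferType before the pass: a dyn.image.resize whose size input is a constant folds back to the static op:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", relay.TensorType((1, 3, 8, 8), "float32"))
    size = relay.const(np.array([16, 16]).astype("int64"))  # constant size
    func = relay.Function([x], relay.image.resize(x, size))

    mod = tvm.ir.IRModule.from_expr(func)
    mod = transform.InferType()(mod)
    mod = transform.DynamicToStatic()(mod)
    # The dynamic call has been replaced by the static image.resize.
    assert mod["main"].body.op == relay.op.get("image.resize")
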
diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py
new file mode 100644 (file)
index 0000000..f0095a1
--- /dev/null
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level5 operator test cases.
+"""
+import math
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.testing import ctx_list, run_infer_type
+import tvm.topi.testing
+
+
+def test_resize_infer_type():
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
+    size = relay.var("size", relay.TensorType((2,), "int64"))
+    z = relay.image.resize(x, size)
+    zz = run_infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8")
+
+
+def test_resize():
+    def verify_resize(dshape, scale, method, layout):
+        if layout == "NHWC":
+            size = (dshape[1] * scale, dshape[2] * scale)
+        else:
+            size = (dshape[2] * scale, dshape[3] * scale)
+        size = np.array(size).astype("int64")
+        x_data = np.random.uniform(size=dshape).astype("float32")
+        if method == "bilinear":
+            ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout)
+        else:
+            ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout)
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        size_var = relay.var("size", relay.TensorType((2,), "int64"))
+        z = relay.image.resize(x, size_var, layout, method, "align_corners")
+        zz = run_infer_type(z)
+        func = relay.Function([x, size_var], z)
+
+        for target, ctx in ctx_list():
+            if "llvm" not in target:
+                continue
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x_data, size)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
+    for method in ["bilinear", "nearest_neighbor"]:
+        for layout in ["NCHW", "NHWC"]:
+            verify_resize((1, 4, 4, 4), 2, method, layout)
+
+
+if __name__ == "__main__":
+    test_resize_infer_type()
+    test_resize()
index a50c9df..5342f2d 100644 (file)
@@ -21,7 +21,6 @@ from tvm import relay
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.testing import run_infer_type, create_workload, ctx_list
-
 import tvm.topi.testing
 
 
@@ -33,14 +32,16 @@ def run_opt_pass(expr, opt_pass):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
-def verify_func(func, data, ref_res):
+
+def verify_func(func, data, ref_res, rtol=1e-5, atol=1e-7):
     assert isinstance(data, list)
     for target, ctx in ctx_list():
         for kind in ["graph", "vm", "debug"]:
             mod = tvm.ir.IRModule.from_expr(func)
             intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
             op_res = intrp.evaluate()(*data)
-            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
+
 
 def test_dynamic_to_static_reshape():
     def verify_reshape(shape, newshape, oshape):
@@ -48,7 +49,8 @@ def test_dynamic_to_static_reshape():
         y = relay.var("y", relay.TensorType(newshape, "float32"))
         z = relay.reshape(x, relay.shape_of(y))
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -64,6 +66,7 @@ def test_dynamic_to_static_reshape():
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
 
+
 def test_dynamic_to_static_double_reshape():
     def verify_reshape(shape, newshape):
         x = relay.var("x", relay.TensorType(shape, "float32"))
@@ -71,7 +74,8 @@ def test_dynamic_to_static_double_reshape():
         z = relay.reshape(x, relay.shape_of(y))
         z = relay.reshape(z, relay.shape_of(x))
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -86,6 +90,7 @@ def test_dynamic_to_static_double_reshape():
     verify_reshape((2, 3, 4), (8, 3))
     verify_reshape((4, 7), (2, 7, 2))
 
+
 def test_dynamic_to_static_quad_reshape():
     def verify_reshape(shape, newshape):
         x = relay.var("x", relay.TensorType(shape, "float32"))
@@ -95,7 +100,8 @@ def test_dynamic_to_static_quad_reshape():
         z3 = relay.reshape(z2, relay.shape_of(z1))
         z4 = relay.reshape(z3, relay.shape_of(z2))
         func = run_infer_type(relay.Function([x, y], z4))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -110,13 +116,15 @@ def test_dynamic_to_static_quad_reshape():
     verify_reshape((2, 3, 4), (8, 3))
     verify_reshape((4, 7), (2, 7, 2))
 
+
 def test_dynamic_to_static_tile():
     def verify_tile(shape, reps, oshape):
         x = relay.var("x", relay.TensorType(shape, "float32"))
         y = relay.var("y", relay.TensorType(reps, "float32"))
         z = relay.tile(x, relay.shape_of(y))
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -131,6 +139,7 @@ def test_dynamic_to_static_tile():
     verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
     verify_tile((4, 7), (4, 2), (16, 14))
 
+
 def test_dynamic_to_static_topk():
     def verify_topk(k, axis, ret_type, is_ascend, dtype):
         shape = (20, 100)
@@ -159,7 +168,8 @@ def test_dynamic_to_static_topk():
                 np_values[i, :] = np_data[i, np_indices[i, :]]
         np_indices = np_indices.astype(dtype)
 
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
         zz = func2.body
         assert isinstance(zz, relay.Call)
         assert zz.op == relay.op.get("topk")
@@ -177,12 +187,15 @@ def test_dynamic_to_static_topk():
                     tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
                 else:
                     tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+
     np.random.seed(0)
     for k in [0, 1, 5]:
         for axis in [0, -1, 1]:
             for ret_type in ["both", "values", "indices"]:
                 verify_topk(k, axis, ret_type, True, "int64")
                 verify_topk(k, axis, ret_type, False, "float32")
+
+
 def test_dynamic_to_static_broadcast_to():
     def verify_broadcast_to(shape, broadcast_shape):
         x = relay.var("x", relay.TensorType(shape, "float32"))
@@ -190,28 +203,32 @@ def test_dynamic_to_static_broadcast_to():
         z = relay.broadcast_to(x, shape=relay.shape_of(y))
 
         func = run_infer_type(relay.Function([x, y], z))
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
         assert zz.op == relay.op.get("broadcast_to")
         assert zz.checked_type == relay.ty.TensorType(broadcast_shape, "float32")
-        
+
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         y_data = np.random.uniform(low=-1, high=1, size=broadcast_shape).astype("float32")
-        
+
         ref_res = np.broadcast_to(x_data, y_data.shape)
         verify_func(func2, [x_data, y_data], ref_res)
+
     verify_broadcast_to((3, 1), (3, 3))
-    
+
+
 def test_dynamic_to_static_zeros_ones():
     def verify_ones_zeros(shape, dtype):
         for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
             x = relay.var("x", relay.TensorType(shape, dtype))
             y = op(relay.shape_of(x), dtype)
-            
+
             func = run_infer_type(relay.Function([x], y))
-            func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+            func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                                 transform.InferType())
 
             zz = func2.body
             assert isinstance(zz, relay.Constant)
@@ -224,6 +241,39 @@ def test_dynamic_to_static_zeros_ones():
     verify_ones_zeros((1, 2, 3), 'int64')
     verify_ones_zeros((9, 8, 3, 4), 'float32')
 
+
+def test_dynamic_to_static_resize():
+    def verify_resize(shape, scale, method, layout):
+        if layout == "NHWC":
+            size = (shape[1] * scale, shape[2] * scale)
+        else:
+            size = (shape[2] * scale, shape[3] * scale)
+
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        size_var = relay.const(np.array(size).astype("int64"))
+        z = relay.image.resize(x, size_var, layout, method, "align_corners")
+
+        func = run_infer_type(relay.Function([x], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("image.resize")
+
+        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+
+        if method == "bilinear":
+            ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout)
+        else:
+            ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout)
+        verify_func(func2, [x_data], ref_res, rtol=1e-4, atol=1e-6)
+
+    for method in ["bilinear", "nearest_neighbor"]:
+        for layout in ["NCHW", "NHWC"]:
+            verify_resize((1, 4, 4, 4), 2, method, layout)
+
+
 def test_dynamic_to_static_one_hot():
     def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
         indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
@@ -233,7 +283,8 @@ def test_dynamic_to_static_one_hot():
         out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype)
         func = relay.Function([indices], out)
 
-        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()),
+                             transform.InferType())
 
         zz = func2.body
         assert isinstance(zz, relay.Call)
@@ -250,7 +301,8 @@ def test_dynamic_to_static_one_hot():
     _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
-if __name__=="__main__":
+
+if __name__ == "__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()
     test_dynamic_to_static_quad_reshape()
@@ -258,3 +310,5 @@ if __name__=="__main__":
     test_dynamic_to_static_topk()
     test_dynamic_to_static_broadcast_to()
     test_dynamic_to_static_zeros_ones()
+    test_dynamic_to_static_resize()
+    test_dynamic_to_static_one_hot()