[RELAY][DYN] Dynamic upsampling relay op (#6273)
author Lily Orth-Smith <lilyorthsmith@gmail.com>
Fri, 21 Aug 2020 16:06:53 +0000 (09:06 -0700)
committer GitHub <noreply@github.com>
Fri, 21 Aug 2020 16:06:53 +0000 (09:06 -0700)
* implementing upsampling op

* fix lint

* fix lint again

* add doc to upsampling shape func

* fix set attrs build problem

* fixing imports

* reverting data layout transform changes

* moved layout template to header file

* changing python module from nn.dyn to dyn.nn

* adding support for more layouts to upsampling

* fix lint

* fix upsampling doc

* change _nn.py doc

* retrigger CI after flaky test failure

* fix build after merge

python/tvm/relay/op/dyn/nn/_nn.py
python/tvm/relay/op/nn/nn.py
python/tvm/topi/nn/upsampling.py
src/relay/op/dyn/nn/upsampling.cc [new file with mode: 0644]
src/relay/op/make_op.h
src/relay/op/nn/upsampling.cc
src/relay/op/nn/upsampling.h [new file with mode: 0644]
src/relay/transforms/dynamic_to_static.cc
tests/python/relay/dyn/test_dynamic_op_level2.py
tests/python/relay/test_pass_dynamic_to_static.py

diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py
index 141fc22..a263561 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
-"""Backend compiler related feature registration"""
+"""Backend compiler related feature registration for dynamic relay ops in nn namespace"""
 
 from __future__ import absolute_import
 
+from tvm import topi
+
+from tvm.runtime import convert
 from tvm.te.hybrid import script
-from ...op import register_shape_func
-from ...op import register_broadcast_schedule
+from ...op import register_shape_func, register_compute
+from ...op import register_injective_schedule, register_broadcast_schedule
 
-# pad
+# upsampling
+@register_compute("dyn.nn.upsampling")
+def compute_upsampling(attrs, inputs, out_dtype):
+    """Compute definition of dyn.nn.upsampling."""
+    data = inputs[0]
+    scale_h = inputs[1]
+    scale_w = inputs[2]
+    layout = attrs.layout
+    method = attrs.method
+    align_corners = attrs.align_corners
+    # out_dtype is the inferred output type; its shape carries the output
+    # extents computed by the shape function below.
+    return [topi.nn.upsampling(data, scale_h, scale_w, layout,
+                               method, align_corners, out_dtype.shape)]
+
+register_injective_schedule("dyn.nn.upsampling")
 register_broadcast_schedule("dyn.nn.pad")
 
 #####################
 #  Shape functions  #
 #####################
 
+# upsampling
+@script
+def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis, channel_axis):
+    out = output_tensor((4,), "int64")
+    out[0] = int64(dshape[0])
+    out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
+    out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
+    out[channel_axis] = int64(dshape[channel_axis])
+    return out
+
+@register_shape_func("dyn.nn.upsampling", True)
+def upsampling_shape_func(attrs, inputs, _):
+    """Shape function for upsampling. Supports NCHW and NHWC layouts."""
+    layout = attrs.layout
+    height_axis = width_axis = channel_axis = 1
+    for i, letter in enumerate(layout):
+        if letter == "H":
+            height_axis = i
+        if letter == "W":
+            width_axis = i
+        if letter == "C":
+            channel_axis = i
+    return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2],
+                                   convert(height_axis), convert(width_axis),
+                                   convert(channel_axis))]
+# pad
 @script
 def _dyn_pad_shape_func(data, pad_width):
     ndim = len(data.shape)
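
For intuition, the hybrid-script shape function above reduces to the following plain-Python computation (an illustrative sketch only, not part of the commit; the helper name is hypothetical):

def upsampling_out_shape(dshape, scale_h, scale_w, layout="NCHW"):
    # Mirror _upsampling_shape_func: keep N and C, round the scaled H and W.
    out = list(dshape)
    out[layout.index("H")] = int(round(dshape[layout.index("H")] * scale_h))
    out[layout.index("W")] = int(round(dshape[layout.index("W")] * scale_w))
    return out

assert upsampling_out_shape((1, 16, 32, 32), 2.0, 2.0) == [1, 16, 64, 64]
assert upsampling_out_shape((1, 32, 32, 16), 2.0, 1.5, layout="NHWC") == [1, 64, 48, 16]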
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index c04db30..6f6849a 100644 (file)
@@ -1152,10 +1152,10 @@ def upsampling(data,
     data : tvm.relay.Expr
         The input data to the operator.
 
-    scale_h : tvm.relay.Expr
+    scale_h : tvm.relay.Expr or int or float
         The scale factor for height upsampling.
 
-    scale_w : tvm.relay.Expr
+    scale_w : tvm.relay.Expr or int or float
         The scale factor for width upsampling.
 
     layout : str, optional
@@ -1172,6 +1172,12 @@ def upsampling(data,
     result : tvm.relay.Expr
         The computed result.
     """
+    if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
+        if not isinstance(scale_h, Expr):
+            scale_h = const(scale_h, "float64")
+        if not isinstance(scale_w, Expr):
+            scale_w = const(scale_w, "float64")
+        return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
     return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
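
In short, the wrapper now dispatches on the scale types: if either scale is a relay.Expr, both are normalized to Exprs and the dynamic op is built; otherwise the static op is used. A minimal usage sketch (illustrative, assuming the usual relay import):

from tvm import relay

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")

# Both scales are Python floats -> static nn.upsampling.
static_up = relay.nn.upsampling(x, 2.0, 2.0)

# Any Expr scale -> dyn.nn.upsampling; the float scale_w is wrapped in a float64 const.
scale_h = relay.var("scale_h", relay.TensorType((), "float64"))
dynamic_up = relay.nn.upsampling(x, scale_h, 2.0)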
 
 
diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py
index d8da41f..db4af06 100644 (file)
@@ -21,7 +21,7 @@ from ..util import simplify
 
 
 def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
-               align_corners=False):
+               align_corners=False, output_shape=None):
     """Perform upsampling on the data.
        Nearest neighbor and bilinear upsampling are supported.
 
@@ -52,16 +52,30 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
     """
     base_layout = layout[0:4]
     if base_layout == "NCHW":
-        out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)),
-                     simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype)))
+        if not output_shape:  # static case
+            scaled_h = data.shape[2] * scale_h
+            scaled_w = data.shape[3] * scale_w
+            reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
+                            simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
+        else:  # dynamic case -- we don't need to scale; already done in shape func
+            reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
+                            simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype)))
     elif layout == "NHWC":
-        out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)),
-                     simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype)))
+        if not output_shape:  # static case
+            scaled_h = data.shape[1] * scale_h
+            scaled_w = data.shape[2] * scale_w
+            reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)),
+                            simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype)))
+        else:  # dynamic case
+            reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)),
+                            simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)))
+
     else:
         raise ValueError("not support this layout {} yet".format(layout))
     coord_trans = "align_corners" if align_corners else "asymmetric"
-    return topi.image.resize(data, out_shape, layout=layout,
-                             method=method, coordinate_transformation_mode=coord_trans)
+    return topi.image.resize(data, reshape_size, layout=layout,
+                             method=method, coordinate_transformation_mode=coord_trans,
+                             output_shape=output_shape)
 
 
 def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
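
A quick sketch of the static path through topi.nn.upsampling after this change (illustrative; in the dynamic path, relay passes the precomputed output_shape instead, so no scaling arithmetic is done here):

from tvm import te, topi

# Static case: scales are plain floats and the output extents come from
# data.shape, e.g. round(32 * 2.0) -> 64.
data = te.placeholder((1, 16, 32, 32), name="data", dtype="float32")
out = topi.nn.upsampling(data, 2.0, 2.0, layout="NCHW")
print(out.shape)  # [1, 16, 64, 64]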
diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc
new file mode 100644 (file)
index 0000000..e271848
--- /dev/null
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file upsampling.cc
+ * \brief upsampling operator
+ */
+
+#include "../../nn/upsampling.h"
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
+#include <vector>
+
+#include "../../op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  // types = [data_type, scale_h_type, scale_w_type, ret_type]
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* scale_h = types[1].as<TensorTypeNode>();
+  const auto* scale_w = types[2].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  if (scale_h == nullptr) return false;
+  if (scale_w == nullptr) return false;
+
+  CHECK_EQ(data->shape.size(), 4);
+  CHECK_EQ(scale_h->shape.size(), 0);
+  CHECK_EQ(scale_w->shape.size(), 0);
+  static const Layout kNCHW("NCHW");
+
+  const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
+  CHECK(param);
+  const Layout in_layout(param->layout);
+
+  auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
+  CHECK(layout_converter.defined())
+      << "UpSampling only supports input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+
+  auto nchw_oshape = layout_converter.ForwardShape(data->shape);
+
+  nchw_oshape.Set(2, Any());
+  nchw_oshape.Set(3, Any());
+  auto oshape = layout_converter.BackwardShape(nchw_oshape);
+
+  reporter->Assign(types[3], TensorType(oshape, data->dtype));
+  return true;
+}
+
+// Positional relay function to create upsampling operator
+// used by frontend FFI.
+Expr MakeUpSampling(Expr data, Expr scale_h, Expr scale_w, String layout, String method,
+                    bool align_corners) {
+  auto attrs = make_object<UpSamplingAttrs>();
+  attrs->layout = std::move(layout);
+  attrs->method = std::move(method);
+  attrs->align_corners = align_corners;
+
+  static const Op& op = Op::Get("dyn.nn.upsampling");
+  return Call(op, {data, scale_h, scale_w}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling").set_body_typed(MakeUpSampling);
+
+RELAY_REGISTER_OP("dyn.nn.upsampling")
+    .describe(
+        R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
+
+- **data**: data is 4D array of shape
+            (batch_size, channels, in_height, in_width) for NCHW
+            (batch_size, in_height, in_width, channels) for NHWC
+
+- **scale_h**: scale_h is a 0-D tensor (double) giving the factor to scale the height by
+
+- **scale_w**: scale_w is a 0-D tensor (double) giving the factor to scale the width by
+
+- **out**: Output is 4D array of shape
+           for layout NCHW
+           (batch_size, channels, in_height*scale_h, in_width*scale_w)
+
+           for layout NHWC
+           (batch_size, in_height*scale_h, in_width*scale_w, channels)
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<UpSamplingAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("scale_h", "double", "The scale for the height.")
+    .add_argument("scale_w", "double", "The scale for the width.")
+    .set_support_level(2)
+    .add_type_rel("DynamicUpSampling", UpSamplingRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   UpsamplingInferCorrectLayout<UpSamplingAttrs>)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace dyn
+}  // namespace relay
+}  // namespace tvm
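
The effect of UpSamplingRel is visible from Python: with scalar Expr scales, type inference keeps N and C but marks H and W as Any. A small sketch mirroring the type-inference test added below:

from tvm import relay
from tvm.relay.testing import run_infer_type

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
scale_h = relay.var("scale_h", relay.TensorType((), "float64"))
scale_w = relay.var("scale_w", relay.TensorType((), "float64"))

z = run_infer_type(relay.nn.upsampling(x, scale_h, scale_w))
assert z.checked_type == relay.TensorType((1, 16, relay.Any(), relay.Any()), "float32")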
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index c759be3..fb3bf02 100644 (file)
@@ -74,6 +74,9 @@ Expr MakeTile(Expr data, Array<Integer> reps);
 
 Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);
 
+Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method,
+                    bool align_corners);
+
 Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
                   bool unbiased);
 
diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc
index cb20881..bdf3090 100644 (file)
  * \file upsampling.cc
  * \brief upsampling operator
  */
+
+#include "upsampling.h"
+
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/tir/data_layout.h>
 
+#include <utility>
 #include <vector>
 
 #include "../op_common.h"
@@ -36,33 +40,6 @@ namespace relay {
 TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
 TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
 
-template <typename T>
-Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
-                                                   const Array<Layout>& new_in_layouts,
-                                                   const Array<Layout>& old_in_layouts,
-                                                   const Array<tvm::relay::Type>& old_in_types) {
-  // NOTE: Discard "const" qualifier here.
-  T* params = const_cast<T*>(attrs.as<T>());
-
-  if (new_in_layouts.defined()) {
-    CHECK_EQ(new_in_layouts.size(), 1);
-
-    Layout raw_layout(params->layout);
-    Layout input = new_in_layouts[0];
-    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
-        input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
-        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
-        (input.IndexOf(LayoutAxis::Get('D')) == -1 ||
-         (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
-          !input.Contains(LayoutAxis::Get('d'))))) {
-      params->layout = input.name();  // modify self to follow the input layout
-    }
-  }
-
-  Layout inferred_layout(params->layout);
-  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
-}
-
 bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 2);
diff --git a/src/relay/op/nn/upsampling.h b/src/relay/op/nn/upsampling.h
new file mode 100644 (file)
index 0000000..e4e3bc9
--- /dev/null
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file src/relay/op/nn/upsampling.h
+ * \brief implementation of the InferCorrectLayout pass for upsampling
+ */
+
+#ifndef TVM_RELAY_OP_NN_UPSAMPLING_H_
+#define TVM_RELAY_OP_NN_UPSAMPLING_H_
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relay {
+
+template <typename T>
+Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
+                                                   const Array<Layout>& new_in_layouts,
+                                                   const Array<Layout>& old_in_layouts,
+                                                   const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard "const" qualifier here.
+  T* params = const_cast<T*>(attrs.as<T>());
+
+  if (new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 1);
+
+    Layout raw_layout(params->layout);
+    Layout input = new_in_layouts[0];
+    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
+        input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
+        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
+        (input.IndexOf(LayoutAxis::Get('D')) == -1 ||
+         (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
+          !input.Contains(LayoutAxis::Get('d'))))) {
+      params->layout = input.name();  // modify self to follow the input layout
+    }
+  }
+
+  Layout inferred_layout(params->layout);
+  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
+}
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_NN_UPSAMPLING_H_
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 3de773e..629f5af 100644 (file)
@@ -124,6 +124,21 @@ class DynamicToStaticMutator : public MixedModeMutator {
            }
            return Expr(nullptr);
          }},
+        {Op::Get("dyn.nn.upsampling"),
+         [](const CallNode* call_node) {
+           const ConstantNode* scale_h = call_node->args[1].as<ConstantNode>();
+           const ConstantNode* scale_w = call_node->args[2].as<ConstantNode>();
+           if (scale_h && scale_w) {
+             CHECK_EQ(scale_h->data->ndim, 0);
+             CHECK_EQ(scale_w->data->ndim, 0);
+             const UpSamplingAttrs* param = call_node->attrs.as<UpSamplingAttrs>();
+             CHECK(param);
+             return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data),
+                                   ToScalar(scale_w->data), param->layout, param->method,
+                                   param->align_corners);
+           }
+           return Expr(nullptr);
+         }},
         {Op::Get("dyn.nn.pad"),
          [](const CallNode* call_node) {
            const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
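
With this mapping registered, dynamic upsampling with constant scales folds back to the static op. An end-to-end sketch (same pattern as the pass test below; illustrative, not part of the commit):

import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
# Expr scales route through dyn.nn.upsampling, but both are constants here.
z = relay.nn.upsampling(x, relay.const(2.0), relay.const(2.0))

mod = tvm.IRModule.from_expr(relay.Function([x], z))
mod = transform.InferType()(mod)
mod = transform.DynamicToStatic()(mod)
print(mod["main"].body.op.name)  # nn.upsampling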
diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py
index 137febd..e1a0d28 100644 (file)
@@ -27,6 +27,55 @@ from test_dynamic_op_level3 import verify_func
 import tvm.topi.testing
 from tvm.relay.testing import run_infer_type
 
+def test_dyn_upsampling_run():
+    def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=False):
+
+        if layout == "NCHW":
+            (n, c, h, w) = dshape
+            x_data = np.random.uniform(size=(n, c, h, w)).astype("float32")
+
+        elif layout == "NHWC":
+            (n, h, w, c) = dshape
+            x_data = np.random.uniform(size=(n, h, w, c)).astype("float32")
+
+        if method == "nearest_neighbor":
+            ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout)
+        else:
+            ref_res = tvm.topi.testing.bilinear_resize_python(
+                x_data, (int(round(h * scale_h)), int(round(w * scale_w))), layout)
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        scale_h_var = relay.var("scale_h", relay.TensorType((), "float32"))
+        scale_w_var = relay.var("scale_w", relay.TensorType((), "float32"))
+
+        z = relay.nn.upsampling(x, scale_h_var, scale_w_var, method=method,
+                                layout=layout, align_corners=align_corners)
+        zz = run_infer_type(z)
+        func = relay.Function([x, scale_h_var, scale_w_var], z)
+
+        for target, ctx in ctx_list():
+            if "llvm" not in target:
+                continue
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"),
+                                          np.array(scale_w).astype("float32"))
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
+
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NCHW", "nearest_neighbor")
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NCHW", "bilinear", True)
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "nearest_neighbor")
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "bilinear", True)
+
+# Tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable.
+def test_dyn_upsampling_infer_type_const():
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+
+    data = relay.var("data", relay.TensorType((n, c, h, w), "int8"))
+    scale_w = relay.var("scale_w", relay.TensorType((), "float32"))
+
+    z = relay.nn.upsampling(data, 2.0, scale_w)
+    zz = run_infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8")
+
 def test_dyn_pad():
     def verify_pad(dshape, pad_width, pad_val, dtype):
         x = relay.var("x", relay.TensorType(dshape, dtype))
@@ -66,3 +115,5 @@ def test_dyn_pad():
 
 if __name__ == "__main__":
     test_dyn_pad()
+    test_dyn_upsampling_infer_type_const()
+    test_dyn_upsampling_run()
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index ed9b94c..c47d959 100644 (file)
@@ -320,6 +320,27 @@ def test_dynamic_to_static_full():
     verify_full(4, (1, 2, 3, 4), 'int32')
     verify_full(4.0, (1, 2, 8, 10), 'float32')
 
+def test_dynamic_to_static_upsampling():
+    def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype):
+        x = relay.var("x", relay.TensorType(data_shape, dtype))
+        scale_h = relay.const(scale_h_val)
+        scale_w = relay.const(scale_w_val)
+        z = relay.nn.upsampling(x, scale_h, scale_w)
+
+        func = run_infer_type(relay.Function([x], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("nn.upsampling")
+
+        x_data = np.random.uniform(size=data_shape).astype(dtype)
+        ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW")
+        verify_func(func2, [x_data], ref_res)
+
+    verify_upsampling((1, 16, 32, 32), 2, 2, 'int8')
+    verify_upsampling((1, 16, 32, 32), 4, 4, 'int32')
+
 def test_dynamic_to_static_pad():
     def verify_pad(data_shape, pad_width, pad_val, dtype):
         x = relay.var("x", relay.TensorType(data_shape, dtype))
@@ -337,7 +358,6 @@ def test_dynamic_to_static_pad():
     verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
     verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
 
-
 if __name__ == "__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()
@@ -349,4 +369,5 @@ if __name__ == "__main__":
     test_dynamic_to_static_resize()
     test_dynamic_to_static_one_hot()
     test_dynamic_to_static_full()
+    test_dynamic_to_static_upsampling()
     test_dynamic_to_static_pad()