[ONNX] Update Slice op conversion to take strides into account, clean up tests (...
authormasahi <masahi129@gmail.com>
Sun, 13 Sep 2020 17:30:46 +0000 (02:30 +0900)
committerGitHub <noreply@github.com>
Sun, 13 Sep 2020 17:30:46 +0000 (10:30 -0700)
Co-authored-by: masa <masa@pop-os.localdomain>
python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index 5f31724..74ac74e 100644 (file)
@@ -1024,6 +1024,9 @@ class Slice(OnnxOpConverter):
         attrs = {"starts": inputs[1], "ends": inputs[2]}
         if len(inputs) >= 4:
             attrs["axes"] = inputs[3]
+        if len(inputs) >= 5:
+            attrs["steps"] = inputs[4]
+
         attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()}
         attrs = {
             k: params[v[1]].asnumpy()
@@ -1033,12 +1036,23 @@ class Slice(OnnxOpConverter):
         }
 
         # Update the starts and ends according to axes if required.
-        if "axes" in attrs:
-            if max(attrs["axes"] + 1) != len(attrs["axes"]):
-                new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"])
-                attrs["starts"] = new_starts
-                attrs["ends"] = new_ends
-        return _op.strided_slice(inputs[0], begin=list(attrs["starts"]), end=list(attrs["ends"]))
+        if "axes" in attrs and max(attrs["axes"] + 1) != len(attrs["axes"]):
+            new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"])
+            attrs["starts"] = new_starts
+            attrs["ends"] = new_ends
+
+        begins = list(attrs["starts"])
+        ends = list(attrs["ends"])
+        strides = [1] * len(begins)
+
+        if "steps" in attrs:
+            steps = list(attrs["steps"])
+            axes = attrs["axes"]
+            assert len(steps) == len(axes)
+            for axis, step in zip(axes, steps):
+                strides[axis] = step
+
+        return _op.strided_slice(inputs[0], begin=begins, end=ends, strides=strides)
 
 
 class Gather(OnnxOpConverter):
index 894a6b6..81c8e77 100644 (file)
@@ -20,10 +20,8 @@ import onnx
 from onnx import helper, TensorProto, mapping
 import torch
 import torchvision
-from tvm import topi
 import tvm.topi.testing
 import tvm
-from tvm import te
 from tvm import relay
 from tvm.contrib import graph_runtime
 import scipy
@@ -52,9 +50,10 @@ def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
 
     ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
-    indata = tvm.nd.array(input_data)
-    result = ex.evaluate()(indata)
-    return result.asnumpy()
+    result = ex.evaluate()(*input_data)
+    if isinstance(result, tvm.runtime.NDArray):
+        return result.asnumpy()
+    return [r.asnumpy() for r in result]
 
 
 def get_tvm_output(
@@ -104,21 +103,71 @@ def get_onnxruntime_output(model, inputs, dtype="float32"):
 
     rep = onnxruntime.backend.prepare(model, "CPU")
     if isinstance(inputs, list) and len(inputs) > 1:
-        ort_out = rep.run(inputs)
+        return rep.run(inputs)
+    elif isinstance(inputs, list) and len(inputs) == 1:
+        inp = inputs[0]
     else:
-        x = inputs.astype(dtype)
-        ort_out = rep.run(x)[0]
-    return ort_out
+        inp = inputs
+    return rep.run(inp.astype(dtype))[0]
+
+
+def verify_with_ort_with_inputs(
+    model,
+    inputs,
+    out_shape=None,
+    targets=None,
+    use_vm=False,
+    opset=None,
+    dtype="float32",
+    rtol=1e-5,
+    atol=1e-5,
+):
+    def flatten(out):
+        if isinstance(out, list) and len(out) == 1:
+            out = out[0]
+        if isinstance(out, np.ndarray):
+            return out.flatten()
+        return out
 
+    ort_out = get_onnxruntime_output(model, inputs, dtype)
 
-def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
-    dtype = "float32"
-    x = np.random.uniform(size=data_shape)
-    model = onnx.load_model(graph_file)
-    c2_out = get_onnxruntime_output(model, x, dtype)
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
-        tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
+    if targets is None:
+        targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()]
+
+    for target in targets:
+        ctx = tvm.context(target, 0)
+
+        if use_vm:
+            tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=opset)
+        else:
+            tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset)
+
+        tvm.testing.assert_allclose(flatten(ort_out), flatten(tvm_out), rtol=rtol, atol=atol)
+
+
+def verify_with_ort(
+    model,
+    input_shapes,
+    out_shape=None,
+    targets=None,
+    use_vm=False,
+    opset=None,
+    dtype="float32",
+    rtol=1e-5,
+    atol=1e-5,
+):
+    inputs = [np.random.uniform(size=ishape).astype(dtype) for ishape in input_shapes]
+    verify_with_ort_with_inputs(
+        model,
+        inputs,
+        out_shape=out_shape,
+        targets=targets,
+        use_vm=use_vm,
+        opset=opset,
+        dtype=dtype,
+        rtol=rtol,
+        atol=atol,
+    )
 
 
 def make_constant_node(name, data_type, dims, vals):
@@ -161,8 +210,7 @@ def test_reshape():
     for target, ctx in tvm.testing.enabled_targets():
         x = np.random.uniform(size=in_shape).astype("int32")
         tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32")
-
-    tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
 
 @tvm.testing.uses_gpu
@@ -193,8 +241,7 @@ def test_expand():
 
         for target, ctx in tvm.testing.enabled_targets():
             tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, "float32")
-
-        tvm.testing.assert_allclose(ref_data, tvm_out)
+            tvm.testing.assert_allclose(ref_data, tvm_out)
 
     in_shape = (3, 1)
     shape = (3, 4)
@@ -221,11 +268,7 @@ def verify_depth_to_space(inshape, outshape, mode, blockSize):
 
     model = helper.make_model(graph, producer_name="depth_to_space_test")
 
-    for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=inshape).astype("float32")
-        tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32")
-        onnx_out = get_onnxruntime_output(model, x, "float32")
-        tvm.testing.assert_allclose(onnx_out, tvm_out)
+    verify_with_ort(model, [inshape], outshape)
 
 
 @tvm.testing.uses_gpu
@@ -248,11 +291,7 @@ def verify_space_to_depth(inshape, outshape, blockSize):
 
     model = helper.make_model(graph, producer_name="space_to_depth_test")
 
-    for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=inshape).astype("float32")
-        tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32")
-        onnx_out = get_onnxruntime_output(model, x, "float32")
-        tvm.testing.assert_allclose(onnx_out, tvm_out)
+    verify_with_ort(model, [inshape], outshape)
 
 
 @tvm.testing.uses_gpu
@@ -293,8 +332,7 @@ def test_shape():
     for target, ctx in tvm.testing.enabled_targets():
         x = np.random.uniform(size=in_shape).astype("int32")
         tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "int32")
-
-    tvm.testing.assert_allclose(ref_shape, tvm_out)
+        tvm.testing.assert_allclose(ref_shape, tvm_out)
 
 
 def _test_power_iteration(x_shape, y_shape):
@@ -350,8 +388,7 @@ def test_squeeze():
     for target, ctx in tvm.testing.enabled_targets():
         x = np.random.uniform(size=in_shape).astype("float32")
         tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32")
-
-    tvm.testing.assert_allclose(out_shape, tvm_out.shape)
+        tvm.testing.assert_allclose(out_shape, tvm_out.shape)
 
 
 @tvm.testing.uses_gpu
@@ -375,8 +412,7 @@ def test_flatten():
     for target, ctx in tvm.testing.enabled_targets():
         x = np.random.uniform(size=in_shape).astype("int32")
         tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32")
-
-    tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
 
 @tvm.testing.uses_gpu
@@ -398,8 +434,7 @@ def test_unsqueeze():
     for target, ctx in tvm.testing.enabled_targets():
         x = np.random.uniform(size=in_shape).astype("float32")
         tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32")
-
-    tvm.testing.assert_allclose(out_shape, tvm_out.shape)
+        tvm.testing.assert_allclose(out_shape, tvm_out.shape)
 
 
 def verify_gather(in_shape, indices, axis, dtype):
@@ -450,11 +485,8 @@ def verify_gatherelements(in_shape, indices, axis):
         outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))],
     )
     model = helper.make_model(graph, producer_name="gather_elements_test")
-    onnx_out = get_onnxruntime_output(model, [x, indices])
 
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x, indices], target, ctx, onnx_out[0].shape)
-        tvm.testing.assert_allclose(onnx_out[0], tvm_out)
+    verify_with_ort_with_inputs(model, [x, indices])
 
 
 @tvm.testing.uses_gpu
@@ -491,11 +523,7 @@ def verify_scatter(in_shape, indices, axis):
         outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))],
     )
     model = helper.make_model(graph, producer_name="scatter_test")
-    onnx_out = get_onnxruntime_output(model, [x, indices, updates])
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x, indices, updates], target, ctx, onnx_out[0].shape)
-        tvm.testing.assert_allclose(onnx_out[0], tvm_out)
+    verify_with_ort_with_inputs(model, [x, indices, updates])
 
 
 @tvm.testing.uses_gpu
@@ -525,14 +553,14 @@ def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=1)
-
-    tvm.testing.assert_allclose(outdata, tvm_out)
+        tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 def _test_slice_iteration_v10(indata, outdata, **attrs):
     starts = attrs["starts"]
     ends = attrs["ends"]
     axes = None if "axes" not in attrs else attrs["axes"]
+    steps = None if "steps" not in attrs else attrs["steps"]
     starts = np.asarray(starts)
     ends = np.asarray(ends)
     inputs = [
@@ -589,8 +617,8 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
             return [ref_node, ref_node2, reshape1_node, reshape2_node]
 
     slice_inputs = []
-    for attr_name in ["starts", "ends", "axes"]:
-        if attr_name == "axes" and not axes:
+    for attr_name in ["starts", "ends", "axes", "steps"]:
+        if attr_name not in attrs:
             continue
         if "add_noop_to_input_attrs" in attrs and attr_name in attrs["add_noop_to_input_attrs"]:
             nodes.extend(add_noop_to_input_attr(attr_name, attrs[attr_name]))
@@ -602,6 +630,13 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
         axes = np.asarray(axes)
         inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT32, list(axes.shape)))
         initializer.append(helper.make_tensor("axes", TensorProto.INT32, list(axes.shape), axes))
+
+    if steps:
+        assert axes is not None and len(axes) == len(steps)
+        steps = np.asarray(steps)
+        inputs.append(helper.make_tensor_value_info("steps", TensorProto.INT32, list(axes.shape)))
+        initializer.append(helper.make_tensor("steps", TensorProto.INT32, list(steps.shape), steps))
+
     y = helper.make_node("Slice", ["data", *slice_inputs], ["out"])
 
     nodes.append(y)
@@ -616,8 +651,7 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10)
-
-    tvm.testing.assert_allclose(outdata, tvm_out)
+        tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 @tvm.testing.uses_gpu
@@ -681,6 +715,19 @@ def test_slice():
         x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3)
     )
 
+    x = np.random.randn(4, 4).astype(np.float32)
+    _test_slice_iteration_v10(
+        x, x[:, 1::2], starts=(1,), ends=(9223372036854775807,), axes=(1,), steps=(2,)
+    )
+    _test_slice_iteration_v10(
+        x,
+        x[0::1, 1::2],
+        starts=(0, 1),
+        ends=(4, 4),
+        axes=(0, 1),
+        steps=(1, 2),
+    )
+
 
 def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
     indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
@@ -699,8 +746,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
-
-    tvm.testing.assert_allclose(outdata, tvm_out)
+        tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 @tvm.testing.uses_gpu
@@ -742,11 +788,7 @@ def test_clip_min_max_as_inputs():
     )
     model = helper.make_model(graph, producer_name="clip_test")
 
-    indata = np.random.uniform(-1, 7, size=input_shape).astype("float32")
-    onnx_out = get_onnxruntime_output(model, indata, "float32")
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, indata, target, ctx, input_shape, "float32")
-    tvm.testing.assert_allclose(onnx_out, tvm_out)
+    verify_with_ort(model, [input_shape], input_shape)
 
 
 @tvm.testing.uses_gpu
@@ -771,8 +813,7 @@ def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs):
 
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
-
-    tvm.testing.assert_allclose(outdata, tvm_out)
+        tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 @tvm.testing.uses_gpu
@@ -1574,10 +1615,7 @@ def verify_reduce_func(func, data, axis, keepdims):
 
     model = helper.make_model(graph, producer_name="reduce_test")
 
-    onnx_out = get_onnxruntime_output(model, data, "float32")
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, data, target, ctx, outshape, "float32")
-        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [data], outshape)
 
 
 @tvm.testing.uses_gpu
@@ -1815,15 +1853,7 @@ def test_prelu():
 
         model = helper.make_model(graph, producer_name="prelu_test")
 
-        indata = np.random.uniform(-10, 10, x_shape).astype(np.float32)
-        slopedata = np.random.uniform(-10, 10, a_shape).astype(np.float32)
-        onnx_out = get_onnxruntime_output(model, [indata, slopedata])
-
-        for target, ctx in [("llvm", tvm.cpu())]:
-            tvm_out = get_tvm_output(
-                model, [indata, slopedata], target, ctx, list(x_shape), output_dtype="float32"
-            )
-            tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
+        verify_with_ort(model, [x_shape, a_shape], list(x_shape))
 
     verify_prelu([3, 4, 5, 6], [1, 4, 1, 1])
     verify_prelu([1, 8, 5, 6], [1, 8, 1, 1])
@@ -1900,11 +1930,8 @@ def check_torch_conversion(model, input_size):
     # Set verbose=True for more output
     torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
     onnx_model = onnx.load(file_name)
-    for target, ctx in tvm.testing.enabled_targets():
-        input_data = np.random.uniform(size=input_size).astype("int32")
-        c2_out = get_onnxruntime_output(onnx_model, input_data)
-        tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
-        tvm.testing.assert_allclose(c2_out, tvm_out)
+    input_data = np.random.uniform(size=input_size).astype("int32")
+    verify_with_ort_with_inputs(onnx_model, [input_data])
 
 
 @tvm.testing.uses_gpu
@@ -2244,18 +2271,9 @@ def test_batch_norm():
         )
 
         model = helper.make_model(graph, producer_name="batchnorm_test")
-
-        for target, ctx in tvm.testing.enabled_targets():
-            x = np.random.uniform(size=in_shape).astype("float32")
-            scale = np.random.uniform(size=in_shape[1]).astype("float32")
-            b = np.random.uniform(size=in_shape[1]).astype("float32")
-            mean = np.random.uniform(size=in_shape[1]).astype("float32")
-            var = np.random.uniform(size=in_shape[1]).astype("float32")
-            onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], "float32")[0]
-            tvm_out = get_tvm_output(
-                model, [x, scale, b, mean, var], target, ctx, in_shape, "float32"
-            )
-            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+        # X, scale, b, mean, var
+        inshapes = [in_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]]
+        verify_with_ort(model, inshapes, in_shape)
 
     verify_batch_norm([1, 3, 224, 224])
     verify_batch_norm([1, 3, 24, 24])
@@ -2288,19 +2306,9 @@ def test_batch_norm_dynamic_subgraph():
         )
 
         model = helper.make_model(graph, producer_name="batchnorm_test")
-
-        for target, ctx in tvm.testing.enabled_targets():
-            x = np.random.uniform(size=in_shape).astype("float32")
-            inp = np.random.uniform(size=o_shape).astype("float32")
-            scale = np.random.uniform(size=in_shape[1]).astype("float32")
-            b = np.random.uniform(size=in_shape[1]).astype("float32")
-            mean = np.random.uniform(size=in_shape[1]).astype("float32")
-            var = np.random.uniform(size=in_shape[1]).astype("float32")
-            onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], "float32")[0]
-            tvm_out = get_tvm_output(
-                model, [x, inp, scale, b, mean, var], target, ctx, in_shape, "float32"
-            )
-            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+        # X, inp, scale, b, mean, var
+        inshapes = [in_shape, o_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]]
+        verify_with_ort(model, inshapes, in_shape, use_vm=False)
 
     verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160])
 
@@ -2364,12 +2372,7 @@ def verify_conv(
 
     model = helper.make_model(graph, producer_name="conv_test")
 
-    for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=x_shape).astype("float32")
-        W = np.random.uniform(size=w_shape).astype("float32")
-        tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
-        onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0]
-        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort(model, [x_shape, w_shape], y_shape)
 
 
 @tvm.testing.uses_gpu
@@ -2476,13 +2479,7 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p):
     )
 
     model = helper.make_model(graph, producer_name="convtranspose_trest")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=x_shape).astype("float32")
-        W = np.random.uniform(size=w_shape).astype("float32")
-        tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape)
-        onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0]
-        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort(model, [x_shape, w_shape], y_shape)
 
 
 @tvm.testing.uses_gpu
@@ -2548,11 +2545,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p
     )
 
     model = helper.make_model(graph, producer_name="pooling_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        onnx_out = get_onnxruntime_output(model, x_np, "float32")
-        tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape)
-        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort(model, [x_shape], out_shape)
 
 
 @tvm.testing.uses_gpu
@@ -2657,12 +2650,7 @@ def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"):
         outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))],
     )
     model = helper.make_model(graph, producer_name="mod_test")
-
-    onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0]
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape)
-        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [x_np, y_np], out_shape)
 
 
 @tvm.testing.uses_gpu
@@ -2731,9 +2719,6 @@ def test_xor():
 
 
 def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape):
-    x_np = np.random.uniform(size=x_shape).astype("float32")
-    rois_np = np.random.uniform(size=rois_shape).astype("float32")
-
     if spatial_scale is None:
         pool_node = helper.make_node(
             "MaxRoiPool", inputs=["x", "rois"], outputs=["y"], pooled_shape=pooled_shape
@@ -2758,11 +2743,7 @@ def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_sh
     )
 
     model = helper.make_model(graph, producer_name="pool_test")
-
-    onnx_out = get_onnxruntime_output(model, [x_np, rois_np], "float32")[0]
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x_np, rois_np], target, ctx, out_shape)
-        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort(model, [x_shape, rois_shape], out_shape)
 
 
 @tvm.testing.uses_gpu
@@ -2785,8 +2766,6 @@ def test_max_roi_pool():
 
 
 def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"):
-    x_np = np.random.uniform(size=x_shape).astype("float32")
-
     if pads is None:
         pool_node = helper.make_node(
             "LpPool",
@@ -2816,11 +2795,7 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="
     )
 
     model = helper.make_model(graph, producer_name="lppool_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        onnx_out = get_onnxruntime_output(model, x_np, "float32")
-        tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape)
-        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort(model, [x_shape], out_shape)
 
 
 @tvm.testing.uses_gpu
@@ -3228,12 +3203,7 @@ def test_resize():
 
         model = helper.make_model(graph, producer_name="resize_test")
 
-        for target, ctx in tvm.testing.enabled_targets():
-            x = np.random.uniform(size=ishape).astype("float32")
-            onnx_out = get_onnxruntime_output(model, x, "float32")
-            tvm_out = get_tvm_output(model, x, target, ctx, oshape, "float32", opset=11)
-
-            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
+        verify_with_ort(model, [ishape], oshape, use_vm=False, opset=11)
 
     # upsampling
     verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric")
@@ -3266,11 +3236,9 @@ def test_nonzero():
 
         model = helper.make_model(graph, producer_name="nonzero_test")
 
-        onnx_out = get_onnxruntime_output(model, indata, dtype)
-
-        for target, ctx in [("llvm", tvm.cpu())]:
-            tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9)
-            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
+        verify_with_ort_with_inputs(
+            model, [indata], targets=["llvm"], dtype="int64", use_vm=True, opset=9
+        )
 
     input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
     result = np.array((np.nonzero(input_data)))  # expected output [[0, 1, 1], [0, 0, 1]]
@@ -3378,17 +3346,7 @@ def test_roi_align():
         np_rois = np.random.uniform(size=[num_roi, 4]).astype("float32") * input_dims[2]
         np_batch_indicies = np.random.randint(low=0, high=input_dims[0], size=num_roi)
 
-        onnx_out = get_onnxruntime_output(model, [np_data, np_rois, np_batch_indicies])
-        for target, ctx in [("llvm", tvm.cpu())]:
-            tvm_out = get_tvm_output(
-                model,
-                [np_data, np_rois, np_batch_indicies],
-                target,
-                ctx,
-                output_dims,
-                output_dtype="float32",
-            )
-            tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
+        verify_with_ort_with_inputs(model, [np_data, np_rois, np_batch_indicies], output_dims)
 
     verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0)
     verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0)