Remove run_infer_type duplicates (#4766)
authorAlexander Pivovarov <pivovaa@amazon.com>
Wed, 22 Jan 2020 06:30:02 +0000 (22:30 -0800)
committermasahi <masahi129@gmail.com>
Wed, 22 Jan 2020 06:30:02 +0000 (15:30 +0900)
tests/python/relay/test_op_level1.py
tests/python/relay/test_op_level10.py
tests/python/relay/test_op_level3.py
tests/python/relay/test_op_level4.py
tests/python/relay/test_op_level5.py
tests/python/relay/test_pass_manager.py
tests/python/relay/test_pass_to_cps.py

index f73826e..194b095 100644 (file)
@@ -20,15 +20,10 @@ import tvm
 import scipy
 from tvm import relay
 from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
 import topi.testing
 from tvm.contrib.nvcc import have_fp16
 
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
 
 def sigmoid(x):
     one = np.ones_like(x)
index bb1d346..6a6f21d 100644 (file)
@@ -21,15 +21,10 @@ import tvm
 import topi.testing
 from tvm import relay
 from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
 import topi
 import topi.testing
 
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_checkpoint():
     dtype = "float32"
index 13f17ca..9c5dfac 100644 (file)
@@ -21,13 +21,8 @@ import pytest
 import tvm
 from tvm import relay
 from tvm.relay import create_executor, transform
-from tvm.relay.testing import ctx_list, check_grad
+from tvm.relay.testing import ctx_list, check_grad, run_infer_type
 
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_zeros_ones():
     for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
index 2b25d6a..0243adc 100644 (file)
@@ -18,14 +18,9 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
 import topi.testing
 
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_binary_op():
     def check_binary_op(opfunc, ref):
index d4abf3d..eb21f33 100644 (file)
@@ -21,14 +21,9 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
 import topi.testing
 
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_resize_infer_type():
     n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
index 3fb8508..e02e917 100644 (file)
@@ -24,14 +24,7 @@ from tvm.relay import ExprFunctor
 from tvm.relay import Function, Call
 from tvm.relay import analysis
 from tvm.relay import transform as _transform
-from tvm.relay.testing import ctx_list
-
-
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = _transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
+from tvm.relay.testing import ctx_list, run_infer_type
 
 
 def get_var_func():
index 045c92c..1d09c0d 100644 (file)
@@ -21,15 +21,11 @@ from tvm.relay.analysis import alpha_equal, detect_feature
 from tvm.relay.transform import to_cps, un_cps
 from tvm.relay.feature import Feature
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass
 from tvm.relay import create_executor
 from tvm.relay import Function, transform
 
 
-def rand(dtype='float32', *shape):
-    return tvm.nd.array(np.random.rand(*shape).astype(dtype))
-
-
 def test_id():
     x = relay.var("x", shape=[])
     id = run_infer_type(relay.Function([x], x))