import scipy
from tvm import relay
from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing
from tvm.contrib.nvcc import have_fp16
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
def sigmoid(x):
one = np.ones_like(x)
import topi.testing
from tvm import relay
from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
import topi
import topi.testing
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
def test_checkpoint():
dtype = "float32"
import tvm
from tvm import relay
from tvm.relay import create_executor, transform
-from tvm.relay.testing import ctx_list, check_grad
+from tvm.relay.testing import ctx_list, check_grad, run_infer_type
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
def test_zeros_ones():
for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
import numpy as np
from tvm import relay
from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
def test_binary_op():
def check_binary_op(opfunc, ref):
import tvm
from tvm import relay
from tvm.relay import transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
def test_resize_infer_type():
n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
from tvm.relay import Function, Call
from tvm.relay import analysis
from tvm.relay import transform as _transform
-from tvm.relay.testing import ctx_list
-
-
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = _transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
+from tvm.relay.testing import ctx_list, run_infer_type
def get_var_func():
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass
from tvm.relay import create_executor
from tvm.relay import Function, transform
-def rand(dtype='float32', *shape):
- return tvm.nd.array(np.random.rand(*shape).astype(dtype))
-
-
def test_id():
x = relay.var("x", shape=[])
id = run_infer_type(relay.Function([x], x))