[PY] GraphRuntime: Update the tutorials to the module-based interface (#6482)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 16 Sep 2020 15:18:06 +0000 (08:18 -0700)
committerGitHub <noreply@github.com>
Wed, 16 Sep 2020 15:18:06 +0000 (08:18 -0700)
* [PY] GraphRuntime: Update the tutorials to the module-based interface.

Also added document about the encouraged usage.
In particular, we encourage the following usage.

lib = relay.build(...)
gmod = graph_runtime.GraphModule(lib["default"](ctx))

I have changed most of the tutorials and apps.
Some follow up PRs are needed to update some of the tests code.

* Fix VTA tutorials

32 files changed:
apps/benchmark/arm_cpu_imagenet_bench.py
apps/benchmark/gpu_imagenet_bench.py
apps/benchmark/mobile_gpu_imagenet_bench.py
apps/ios_rpc/tests/ios_rpc_mobilenet.py
python/tvm/contrib/graph_runtime.py
python/tvm/relay/backend/graph_runtime_factory.py
python/tvm/relay/frontend/common.py
python/tvm/relay/param_dict.py
python/tvm/relay/quantize/_calibrate.py
tests/python/contrib/test_coreml_codegen.py
tests/python/contrib/test_ethosn/infrastructure.py
tests/python/contrib/test_ethosn/test_topologies.py
tests/python/frontend/caffe2/test_forward.py
tests/python/frontend/coreml/test_forward.py
tests/python/frontend/darknet/test_forward.py
tests/python/frontend/keras/test_forward.py
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/pytorch/qnn_test.py
tests/python/frontend/tflite/test_forward.py
tests/python/relay/benchmarking/benchmark_vm.py
tests/python/relay/test_cpp_build_module.py
tests/python/relay/test_simplify_fc_transpose.py
tests/python/unittest/test_target_codegen_blob.py
tests/python/unittest/test_tir_transform_hoist_if.py
tutorials/autotvm/tune_relay_arm.py
tutorials/autotvm/tune_relay_cuda.py
tutorials/autotvm/tune_relay_mobile_gpu.py
tutorials/autotvm/tune_relay_x86.py
tutorials/get_started/relay_quick_start.py
vta/tutorials/autotvm/tune_relay_vta.py
vta/tutorials/frontend/deploy_classification.py
vta/tutorials/frontend/legacy/deploy_detection.py

index a4a88d8..fb58819 100644 (file)
@@ -40,7 +40,7 @@ def evaluate_network(network, target, target_host, repeat):
 
     print_progress("%-20s building..." % network)
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(net, target=target, target_host=target_host, params=params)
+        lib = relay.build(net, target=target, target_host=target_host, params=params)
 
     tmp = tempdir()
     if "android" in str(target):
@@ -58,10 +58,9 @@ def evaluate_network(network, target, target_host, repeat):
     remote.upload(tmp.relpath(filename))
 
     rlib = remote.load_module(filename)
-    module = runtime.create(graph, rlib, ctx)
+    module = runtime.GraphModule(rlib["default"](ctx))
     data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
     module.set_input("data", data_tvm)
-    module.set_input(**params)
 
     # evaluate
     print_progress("%-20s evaluating..." % network)
index a1c0cc6..b78476f 100644 (file)
@@ -34,14 +34,13 @@ def benchmark(network, target):
     net, params, input_shape, output_shape = get_network(network, batch_size=1)
 
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(net, target=target, params=params)
+        lib = relay.build(net, target=target, params=params)
 
     # create runtime
     ctx = tvm.context(str(target), 0)
-    module = runtime.create(graph, lib, ctx)
+    module = runtime.GraphModule(lib["default"](ctx))
     data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
     module.set_input("data", data_tvm)
-    module.set_input(**params)
 
     # evaluate
     ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=args.repeat)
index fa1af54..b57f602 100644 (file)
@@ -40,7 +40,7 @@ def evaluate_network(network, target, target_host, dtype, repeat):
 
     print_progress("%-20s building..." % network)
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(net, target=target, target_host=target_host, params=params)
+        lib = relay.build(net, target=target, target_host=target_host, params=params)
 
     tmp = tempdir()
     if "android" in str(target) or "android" in str(target_host):
@@ -58,10 +58,9 @@ def evaluate_network(network, target, target_host, dtype, repeat):
     remote.upload(tmp.relpath(filename))
 
     rlib = remote.load_module(filename)
-    module = runtime.create(graph, rlib, ctx)
+    module = runtime.GraphModule(rlib["default"](ctx))
     data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
     module.set_input("data", data_tvm)
-    module.set_input(**params)
 
     # evaluate
     print_progress("%-20s evaluating..." % network)
index 642c7da..daac680 100644 (file)
@@ -104,9 +104,7 @@ def test_mobilenet():
 
     def run(mod, target):
         with relay.build_config(opt_level=3):
-            graph, lib, _params = relay.build(
-                mod, target=target, target_host=target_host, params=params
-            )
+            lib = relay.build(mod, target=target, target_host=target_host, params=params)
         path_dso = temp.relpath("deploy.dylib")
         lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk)
         xcode.codesign(path_dso)
@@ -122,10 +120,9 @@ def test_mobilenet():
         else:
             ctx = remote.cpu(0)
         lib = remote.load_module("deploy.dylib")
-        m = graph_runtime.create(graph, lib, ctx)
+        m = graph_runtime.GraphModule(lib["default"](ctx))
 
         m.set_input("data", tvm.nd.array(image, ctx))
-        m.set_input(**_params)
         m.run()
         tvm_output = m.get_output(0)
         top1 = np.argmax(tvm_output.asnumpy()[0])
index de941af..a960e55 100644 (file)
@@ -47,6 +47,12 @@ def create(graph_json_str, libmod, ctx):
     -------
     graph_module : GraphModule
         Runtime graph module that can be used to execute the graph.
+
+    Note
+    ----
+    See also :py:class:`tvm.contrib.graph_runtime.GraphModule`
+    for examples to directly construct a GraphModule from an exported
+    relay compiled library.
     """
     assert isinstance(graph_json_str, string_types)
 
@@ -121,6 +127,27 @@ class GraphModule(object):
     ----------
     module : tvm.runtime.Module
         The internal tvm module that holds the actual graph functions.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import tvm
+        from tvm import relay
+        from tvm.contrib import graph_runtime
+
+        # build the library using graph runtime
+        lib = relay.build(...)
+        lib.export_library("compiled_lib.so")
+        # load it back as a runtime
+        lib:tvm.runtime.Module = tvm.runtime.load_module("compiled_lib.so")
+        # Call the library factory function for default and create
+        # a new runtime.Module, wrap with graph module.
+        gmod = graph_runtime.GraphModule(lib["default"](ctx))
+        # use the gmod
+        gmod.set_input("x", data)
+        gmod.run()
     """
 
     def __init__(self, module):
index a21a4a8..681a842 100644 (file)
@@ -74,7 +74,9 @@ class GraphRuntimeFactoryModule(object):
     def __iter__(self):
         warnings.warn(
             "legacy graph runtime behaviour of producing json / lib / params will be "
-            "removed in the next release ",
+            " removed in the next release."
+            " Please see documents of tvm.contrib.graph_runtime.GraphModule for the "
+            " new recommended usage.",
             DeprecationWarning,
             2,
         )
index b46cd27..e4d605a 100644 (file)
@@ -522,10 +522,9 @@ def infer_value(input_val, params, mod=None):
 
         func = _function.Function(analysis.free_vars(input_val), input_val)
         with tvm.transform.PassContext(opt_level=0):
-            graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            lib = tvm.relay.build(func, target="llvm", params=params)
         ctx = tvm.cpu(0)
-        m = graph_runtime.create(graph, lib, ctx)
-        m.set_input(**params)
+        m = graph_runtime.GraphModule(lib["default"](ctx))
         m.run()
         return m.get_output(0)
     except Exception:
index 2a01353..2d0398e 100644 (file)
@@ -44,14 +44,13 @@ def save_param_dict(params):
     --------
     .. code-block:: python
 
-       # compile and save the modules to file.
-       graph, lib, params = tvm.relay.build(func, target=target, params=params)
-       module = graph_runtime.create(graph, lib, tvm.gpu(0))
+       # set up the parameter dict
+       params = {"param0": arr0, "param1": arr1}
        # save the parameters as byte array
        param_bytes = tvm.relay.save_param_dict(params)
        # We can serialize the param_bytes and load it back later.
        # Pass in byte array to module to directly set parameters
-       module.load_params(param_bytes)
+       graph_runtime_mod.load_params(param_bytes)
     """
     args = []
     for k, v in params.items():
index c460dad..8461daa 100644 (file)
@@ -45,9 +45,8 @@ def _get_profile_runtime(mod):
         ctx = tvm.context(target)
 
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = _build_module.build(func, target=target)
-    runtime = graph_runtime.create(graph, lib, ctx)
-    runtime.set_input(**params)
+        lib = _build_module.build(func, target=target)
+    runtime = graph_runtime.GraphModule(lib["default"](ctx))
 
     return runtime
 
index cd51668..018aade 100644 (file)
@@ -102,8 +102,8 @@ def test_compile_and_run():
     tol = 1e-3
 
     with relay.build_config(opt_level=3):
-        json, lib, params = relay.build(_create_graph_annotated(), target=target)
-    m = tvm.contrib.graph_runtime.create(json, lib, ctx)
+        lib = relay.build(_create_graph_annotated(), target=target)
+    m = tvm.contrib.graph_runtime.GraphModule(lib["default"](ctx))
 
     shape = (10, 10)
     x_data = np.random.rand(*shape).astype("float32")
@@ -111,7 +111,6 @@ def test_compile_and_run():
 
     m.set_input("x", x_data)
     m.set_input("y", y_data)
-    m.set_input(**params)
     m.run()
     out = tvm.nd.empty(shape, ctx=ctx)
     out = m.get_output(0, out)
index 31ebb1a..0d5e1e5 100644 (file)
@@ -132,16 +132,15 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1):
             return relay.build(mod, params=params)
 
 
-def run(graph, lib, params, inputs, outputs, npu=True):
+def run(lib, inputs, outputs, npu=True):
     # Export and load lib to confirm this works
     lib_name = "mod.so"
     temp = util.tempdir()
     lib_path = temp.relpath(lib_name)
     lib.export_library(lib_path)
     lib = tvm.runtime.load_module(lib_path)
-    module = graph_runtime.create(graph, lib, tvm.cpu())
+    module = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
     module.set_input(**inputs)
-    module.set_input(**params)
     module.run()
     out = [module.get_output(i) for i in range(outputs)]
     if not npu:
@@ -152,8 +151,8 @@ def run(graph, lib, params, inputs, outputs, npu=True):
 def build_and_run(
     mod, inputs, outputs, params, ctx=tvm.cpu(), npu=True, expected_host_ops=0, npu_partitions=1
 ):
-    graph, lib, params = build(mod, params, npu, expected_host_ops, npu_partitions)
-    return run(graph, lib, params, inputs, outputs, npu)
+    lib = build(mod, params, npu, expected_host_ops, npu_partitions)
+    return run(lib, inputs, outputs, npu)
 
 
 def verify(answers, atol, rtol=1e-07, verify_saturation=True):
index 89099db..f4e8beb 100644 (file)
@@ -243,7 +243,7 @@ def test_input_tuples():
             mod = tei.make_module(model, {})
         else:
             mod = tei.make_ethosn_partition(model)
-        graph, lib, params = tei.build(mod, {}, npu=False)
-        outputs.append(tei.run(graph, lib, {}, inputs, 1, npu=npu))
+        lib = tei.build(mod, {}, npu=False)
+        outputs.append(tei.run(lib, inputs, 1, npu=npu))
 
     tei.verify(outputs, 0)
index 75f9371..a45b86c 100644 (file)
@@ -40,13 +40,12 @@ def get_tvm_output(model, input_data, target, ctx, output_shape, output_dtype="f
         model.init_net, model.predict_net, shape_dict, dtype_dict
     )
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(mod, target, params=params)
+        lib = relay.build(mod, target, params=params)
 
-    m = graph_runtime.create(graph, lib, ctx)
+    m = graph_runtime.GraphModule(lib["default"](ctx))
 
     # set inputs
     m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
-    m.set_input(**params)
 
     # execute
     m.run()
index d808469..1d3f6c9 100644 (file)
@@ -36,11 +36,10 @@ def get_tvm_output(
     func, x, params, target, ctx, out_shape=(1, 1000), input_name="image", dtype="float32"
 ):
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(func, target, params=params)
-    m = graph_runtime.create(graph, lib, ctx)
+        lib = relay.build(func, target, params=params)
+    m = graph_runtime.GraphModule(lib["default"](ctx))
     # set inputs
     m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
-    m.set_input(**params)
     m.run()
     # get outputs
     out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
@@ -87,11 +86,11 @@ def run_tvm_graph(
 
     mod, params = relay.frontend.from_coreml(coreml_model, shape_dict)
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(mod, target, params=params)
+        lib = relay.build(mod, target, params=params)
 
     from tvm.contrib import graph_runtime
 
-    m = graph_runtime.create(graph, lib, ctx)
+    m = graph_runtime.GraphModule(lib["default"](ctx))
     # set inputs
     if isinstance(input_data, list):
         for i, e in enumerate(input_name):
@@ -99,7 +98,6 @@ def run_tvm_graph(
     else:
         m.set_input(input_name, tvm.nd.array(input_data.astype(input_data.dtype)))
 
-    m.set_input(**params)
     # execute
     m.run()
     # get outputs
index c7bc775..9e21a86 100644 (file)
@@ -61,17 +61,16 @@ def _get_tvm_output(net, data, build_dtype="float32", states=None):
     mod, params = relay.frontend.from_darknet(net, data.shape, dtype)
     target = "llvm"
     shape_dict = {"data": data.shape}
-    graph, library, params = relay.build(mod, target, params=params)
+    lib = relay.build(mod, target, params=params)
 
     # Execute on TVM
     ctx = tvm.cpu(0)
-    m = graph_runtime.create(graph, library, ctx)
+    m = graph_runtime.GraphModule(lib["default"](ctx))
     # set inputs
     m.set_input("data", tvm.nd.array(data.astype(dtype)))
     if states:
         for name in states.keys():
             m.set_input(name, tvm.nd.array(states[name].astype(dtype)))
-    m.set_input(**params)
     m.run()
     # get outputs
     tvm_out = []
index 3ba7a03..251c887 100644 (file)
@@ -88,11 +88,10 @@ def verify_keras_frontend(keras_model, need_transpose=True, layout="NCHW"):
         shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
         mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout)
         with tvm.transform.PassContext(opt_level=2):
-            graph, lib, params = relay.build(mod, target, params=params)
-        m = graph_runtime.create(graph, lib, ctx)
+            lib = relay.build(mod, target, params=params)
+        m = graph_runtime.GraphModule(lib["default"](ctx))
         for name, x in zip(keras_model.input_names, xs):
             m.set_input(name, tvm.nd.array(x.astype(dtype)))
-        m.set_input(**params)
         m.run()
         return [m.get_output(i).asnumpy() for i in range(m.get_num_outputs())]
 
index 639f8e2..52c5db8 100644 (file)
@@ -77,11 +77,10 @@ def verify_mxnet_frontend_impl(
                 symbol, shape_dict, arg_params=args, aux_params=auxs
             )
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build(mod, target, params=params)
-        m = graph_runtime.create(graph, lib, ctx)
+            lib = relay.build(mod, target, params=params)
+        m = graph_runtime.GraphModule(lib["default"](ctx))
         # set inputs
         m.set_input("data", tvm.nd.array(x.astype(dtype)))
-        m.set_input(**params)
         m.run()
         # get outputs
         out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
index f34dcef..ac24a1f 100644 (file)
@@ -45,10 +45,9 @@ def get_tvm_runtime(script_module, input_name, ishape):
     with tvm.transform.PassContext(opt_level=3):
         # test on only cpu for now, torch cannot run quant models on cuda
         # also not to make CI too slow
-        json, lib, params = relay.build(mod, target="llvm", params=params)
+        lib = relay.build(mod, target="llvm", params=params)
 
-    runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0))
-    runtime.set_input(**params)
+    runtime = tvm.contrib.graph_runtime.GraphModule(lib["default"](tvm.cpu(0)))
     return runtime
 
 
index e8b225c..7d67427 100644 (file)
@@ -199,17 +199,15 @@ def run_tvm_graph(
         return vmobj_to_list(result)
     else:
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build(mod, target, params=params)
+            lib = relay.build(mod, target, params=params)
 
         ctx = tvm.context(target, 0)
         from tvm.contrib import graph_runtime
 
-        m = graph_runtime.create(graph, lib, ctx)
+        m = graph_runtime.GraphModule(lib["default"](ctx))
         # set inputs
         for i, e in enumerate(input_node):
             m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
-
-        m.set_input(**params)
         # execute
         m.run()
         # get outputs
index c91a554..02e6ae6 100644 (file)
@@ -40,12 +40,11 @@ def benchmark_execution(
         mod, data, params, target, ctx, dtype="float32", number=2, repeat=20
     ):
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build(mod, target, params=params)
+            lib = relay.build(mod, target, params=params)
 
-        m = graph_runtime.create(graph, lib, ctx)
+        m = graph_runtime.GraphModule(lib["default"](ctx))
         # set inputs
         m.set_input("data", data)
-        m.set_input(**params)
         m.run()
         out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
 
index fe44eb2..67f0621 100644 (file)
@@ -44,13 +44,12 @@ def test_basic_build():
     func_in_mod = mod["main"]
     assert mod["main"] == func_in_mod, "cannot compare function to itself"
 
-    g_json, mmod, params = relay.build(mod, targets, "llvm", params=params)
+    lib = relay.build(mod, targets, "llvm", params=params)
     assert mod["main"] == func_in_mod, "relay.build changed module in-place"
 
     # test
-    rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
+    rt = tvm.contrib.graph_runtime.GraphModule(lib["default"](ctx))
     rt.set_input("a", A)
-    rt.load_params(relay.save_param_dict(params))
     rt.run()
     out = rt.get_output(0)
 
index 26fc356..d5d195d 100644 (file)
@@ -29,16 +29,15 @@ from tvm.relay.data_dep_optimization import simplify_fc_transpose
 
 def run_func(func, params, x):
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, new_params = relay.build(func, "llvm", params=params)
+        lib = relay.build(func, "llvm", params=params)
 
     from tvm.contrib import graph_runtime
 
     ctx = tvm.cpu(0)
     dtype = "float32"
-    m = graph_runtime.create(graph, lib, ctx)
+    m = graph_runtime.GraphModule(lib["default"](ctx))
     # set inputs
     m.set_input("data", tvm.nd.array(x.astype(dtype)))
-    m.set_input(**new_params)
     # execute
     m.run()
     # get outputs
index dedabf4..8189d13 100644 (file)
@@ -37,38 +37,28 @@ def test_synthetic():
     def verify(data):
         mod, params = relay.testing.synthetic.get_workload(input_shape=input_shape)
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params)
+            lib = relay.build_module.build(mod, "llvm", params=params)
         ctx = tvm.cpu()
-        module = graph_runtime.create(graph, lib, ctx)
+        module = graph_runtime.GraphModule(lib["default"](ctx))
         module.set_input("data", data)
-        module.set_input(**graph_params)
         module.run()
         out = module.get_output(0).asnumpy()
         return out
 
     synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload(input_shape=input_shape)
     with tvm.transform.PassContext(opt_level=3):
-        graph, synthetic_gpu_lib, graph_params = relay.build_module.build(
-            synthetic_mod, "cuda", params=synthetic_params
-        )
+        synthetic_gpu_lib = relay.build_module.build(synthetic_mod, "cuda", params=synthetic_params)
 
     from tvm.contrib import util
 
     temp = util.tempdir()
     path_lib = temp.relpath("deploy_lib.so")
     synthetic_gpu_lib.export_library(path_lib)
-    with open(temp.relpath("deploy_graph.json"), "w") as fo:
-        fo.write(graph)
-    with open(temp.relpath("deploy_param.params"), "wb") as fo:
-        fo.write(relay.save_param_dict(graph_params))
 
     loaded_lib = tvm.runtime.load_module(path_lib)
-    loaded_json = open(temp.relpath("deploy_graph.json")).read()
-    loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
     data = np.random.uniform(-1, 1, size=input_shape).astype("float32")
     ctx = tvm.gpu()
-    module = graph_runtime.create(loaded_json, loaded_lib, ctx)
-    module.load_params(loaded_params)
+    module = graph_runtime.GraphModule(loaded_lib["default"](ctx))
     module.set_input("data", data)
     module.run()
     out = module.get_output(0).asnumpy()
index 2cbb665..77ab38d 100644 (file)
@@ -764,12 +764,11 @@ def test_hoisting_op_conv():
     params = {"w": tvm.nd.array(kernel)}
     for target, ctx in enabled_targets():
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
-            m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+            lib = relay.build_module.build(mod, target=target, params=params)
+            m = tvm.contrib.graph_runtime.GraphModule(lib["default"](ctx))
             x = np.random.uniform(size=dshape)
             data_tvm = tvm.nd.array(data)
             m.set_input("x", data_tvm)
-            m.set_input(**params)
             m.run()
             e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
             t1 = e(data_tvm).results
@@ -779,8 +778,8 @@ def test_hoisting_op_conv():
         with tvm.transform.PassContext(
             opt_level=3, config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
         ):
-            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
-            m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+            lib = relay.build_module.build(mod, target=target, params=params)
+            m = tvm.contrib.graph_runtime.GraphModule(lib["default"](ctx))
             x = np.random.uniform(size=dshape)
             data_tvm = tvm.nd.array(data)
             m.set_input("x", data_tvm)
index 4cfde72..a336870 100644 (file)
@@ -326,7 +326,7 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.apply_history_best(log_file):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+            lib = relay.build_module.build(mod, target=target, params=params)
 
         # export library
         tmp = tempdir()
@@ -347,10 +347,9 @@ def tune_and_evaluate(tuning_opt):
 
         # upload parameters to device
         ctx = remote.context(str(target), 0)
-        module = runtime.create(graph, rlib, ctx)
+        module = runtime.GraphModule(rlib["default"](ctx))
         data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
         module.set_input("data", data_tvm)
-        module.set_input(**params)
 
         # evaluate
         print("Evaluate inference time cost...")
index 64be5eb..32ee266 100644 (file)
@@ -233,7 +233,7 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.apply_history_best(log_file):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+            lib = relay.build_module.build(mod, target=target, params=params)
 
         # export library
         tmp = tempdir()
@@ -242,10 +242,9 @@ def tune_and_evaluate(tuning_opt):
 
         # load parameters
         ctx = tvm.context(str(target), 0)
-        module = runtime.create(graph, lib, ctx)
+        module = runtime.GraphModule(lib["default"](ctx))
         data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
         module.set_input("data", data_tvm)
-        module.set_input(**params)
 
         # evaluate
         print("Evaluate inference time cost...")
index 1fa2326..19fa601 100644 (file)
@@ -325,7 +325,7 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.apply_history_best(log_file):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(
+            lib = relay.build_module.build(
                 mod, target=target, params=params, target_host=target_host
             )
         # export library
@@ -347,10 +347,9 @@ def tune_and_evaluate(tuning_opt):
 
         # upload parameters to device
         ctx = remote.context(str(target), 0)
-        module = runtime.create(graph, rlib, ctx)
+        module = runtime.GraphModule(rlib["default"](ctx))
         data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
         module.set_input("data", data_tvm)
-        module.set_input(**params)
 
         # evaluate
         print("Evaluate inference time cost...")
index 8816824..1dd947f 100644 (file)
@@ -208,14 +208,13 @@ def tune_and_evaluate(tuning_opt):
     with autotvm.apply_graph_best(graph_opt_sch_file):
         print("Compile...")
         with tvm.transform.PassContext(opt_level=3):
-            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+            lib = relay.build_module.build(mod, target=target, params=params)
 
         # upload parameters to device
         ctx = tvm.cpu()
         data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
-        module = runtime.create(graph, lib, ctx)
+        module = runtime.GraphModule(lib["default"](ctx))
         module.set_input(input_name, data_tvm)
-        module.set_input(**params)
 
         # evaluate
         print("Evaluate inference time cost...")
index 437a22f..5c7f933 100644 (file)
@@ -98,7 +98,7 @@ print(mod.astext(show_meta_data=False))
 opt_level = 3
 target = tvm.target.cuda()
 with tvm.transform.PassContext(opt_level=opt_level):
-    graph, lib, params = relay.build(mod, target, params=params)
+    lib = relay.build(mod, target, params=params)
 
 #####################################################################
 # Run the generate library
@@ -109,10 +109,9 @@ with tvm.transform.PassContext(opt_level=opt_level):
 ctx = tvm.gpu()
 data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
 # create module
-module = graph_runtime.create(graph, lib, ctx)
+module = graph_runtime.GraphModule(lib["default"](ctx))
 # set input and parameters
 module.set_input("data", data)
-module.set_input(**params)
 # run
 module.run()
 # get output
@@ -135,22 +134,15 @@ from tvm.contrib import util
 temp = util.tempdir()
 path_lib = temp.relpath("deploy_lib.tar")
 lib.export_library(path_lib)
-with open(temp.relpath("deploy_graph.json"), "w") as fo:
-    fo.write(graph)
-with open(temp.relpath("deploy_param.params"), "wb") as fo:
-    fo.write(relay.save_param_dict(params))
 print(temp.listdir())
 
 ####################################################
 
 # load the module back.
-loaded_json = open(temp.relpath("deploy_graph.json")).read()
 loaded_lib = tvm.runtime.load_module(path_lib)
-loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
 input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
 
-module = graph_runtime.create(loaded_json, loaded_lib, ctx)
-module.load_params(loaded_params)
+module = graph_runtime.GraphModule(loaded_lib["default"](ctx))
 module.run(data=input_data)
 out_deploy = module.get_output(0).asnumpy()
 
index d1a2e85..41fd04e 100644 (file)
@@ -413,12 +413,12 @@ def tune_and_evaluate(tuning_opt):
         print("Compile...")
         if target.device_name != "vta":
             with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
-                graph, lib, params = relay.build(
+                lib = relay.build(
                     relay_prog, target=target, params=params, target_host=env.target_host
                 )
         else:
             with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
-                graph, lib, params = relay.build(
+                lib = relay.build(
                     relay_prog, target=target, params=params, target_host=env.target_host
                 )
 
@@ -431,11 +431,10 @@ def tune_and_evaluate(tuning_opt):
 
         # Generate the graph runtime
         ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
-        m = graph_runtime.create(graph, lib, ctx)
+        m = graph_runtime.GraphModule(lib["default"](ctx))
 
         # upload parameters to device
         image = tvm.nd.array((np.random.uniform(size=(1, 3, 224, 224))).astype("float32"))
-        m.set_input(**params)
         m.set_input("data", image)
 
         # evaluate
index 74c7412..04716ce 100644 (file)
@@ -197,9 +197,7 @@ with autotvm.tophub.context(target):
             )
     else:
         with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
-            graph, lib, params = relay.build(
-                relay_prog, target=target, params=params, target_host=env.target_host
-            )
+            lib = relay.build(relay_prog, target=target, params=params, target_host=env.target_host)
 
     # Measure Relay build time
     build_time = time.time() - build_start
@@ -207,12 +205,12 @@ with autotvm.tophub.context(target):
 
     # Send the inference library over to the remote RPC server
     temp = util.tempdir()
-    lib.save(temp.relpath("graphlib.o"))
-    remote.upload(temp.relpath("graphlib.o"))
-    lib = remote.load_module("graphlib.o")
+    lib.export_library(temp.relpath("graphlib.tar"))
+    remote.upload(temp.relpath("graphlib.tar"))
+    lib = remote.load_module("graphlib.tar")
 
     # Graph runtime
-    m = graph_runtime.create(graph, lib, ctx)
+    m = graph_runtime.GraphModule(lib["default"](ctx))
 
 ######################################################################
 # Perform image classification inference
@@ -243,7 +241,6 @@ image = image[np.newaxis, :]
 image = np.repeat(image, env.BATCH, axis=0)
 
 # Set the network parameters and inputs
-m.set_input(**params)
 m.set_input("data", image)
 
 # Perform inference and gather execution statistics
index f6fd462..010ee31 100644 (file)
@@ -234,9 +234,7 @@ with autotvm.tophub.context(target):
 
     # Compile Relay program with AlterOpLayout disabled
     with vta.build_config(disabled_pass={"AlterOpLayout"}):
-        graph, lib, params = relay.build(
-            mod, target=target, params=params, target_host=env.target_host
-        )
+        lib = relay.build(mod, target=target, params=params, target_host=env.target_host)
 
     # Measure Relay build time
     build_time = time.time() - build_start
@@ -244,12 +242,12 @@ with autotvm.tophub.context(target):
 
     # Send the inference library over to the remote RPC server
     temp = util.tempdir()
-    lib.save(temp.relpath("graphlib.o"))
-    remote.upload(temp.relpath("graphlib.o"))
-    lib = remote.load_module("graphlib.o")
+    lib.export_library(temp.relpath("graphlib.tar"))
+    remote.upload(temp.relpath("graphlib.tar"))
+    lib = remote.load_module("graphlib.tar")
 
     # Graph runtime
-    m = graph_runtime.create(graph, lib, ctx)
+    m = graph_runtime.GraphModule(lib["default"](ctx))
 
 ####################################
 # Perform image detection inference.
@@ -271,7 +269,6 @@ data = np.repeat(data, env.BATCH, axis=0)
 
 # Set the network parameters and inputs
 m.set_input("data", data)
-m.set_input(**params)
 
 # Perform inference and gather execution statistics
 # More on: :py:method:`tvm.runtime.Module.time_evaluator`