--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_factory.cc
+ * \brief Graph runtime factory implementations
+ */
+
+#include "./graph_runtime_factory.h"
+
+#include <tvm/node/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include <iterator>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+GraphRuntimeFactory::GraphRuntimeFactory(
+    const std::string& graph_json,
+    const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
+    const std::string& module_name)
+    : graph_json_(graph_json), params_(params), module_name_(module_name) {}
+
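+// Dispatches the packed functions exposed by the factory module:
+//   <module_name>  : creates a graph runtime on the given contexts.
+//   debug_create   : creates a debug graph runtime (requires the debug runtime build).
+//   remove_params  : returns a copy of this factory without the packaged params.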
+PackedFunc GraphRuntimeFactory::GetFunction(
+ const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
+ if (name == module_name_) {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ std::vector<TVMContext> contexts;
+ for (int i = 0; i < args.num_args; ++i) {
+ contexts.emplace_back(args[i].operator TVMContext());
+ }
+ *rv = this->RuntimeCreate(contexts);
+ });
+ } else if (name == "debug_create") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ CHECK_GE(args.size(), 2);
+ std::string module_name = args[0].operator String();
+      CHECK(module_name == module_name_) << "Currently we only support a single model.";
+ std::vector<TVMContext> contexts;
+ for (int i = 1; i < args.num_args; ++i) {
+ contexts.emplace_back(args[i].operator TVMContext());
+ }
+ *rv = this->DebugRuntimeCreate(contexts);
+ });
+ } else if (name == "remove_params") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ std::unordered_map<std::string, tvm::runtime::NDArray> empty_params{};
+ auto exec =
+ make_object<GraphRuntimeFactory>(this->graph_json_, empty_params, this->module_name_);
+ exec->Import(this->imports_[0]);
+ *rv = Module(exec);
+ });
+ } else {
+ return PackedFunc();
+ }
+}
+
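+// Serialization layout: the graph JSON, the number of params, the param names,
+// the param NDArrays, and finally the module name. This order is mirrored by
+// GraphRuntimeFactoryModuleLoadBinary below.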
+void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) {
+ stream->Write(graph_json_);
+ std::vector<std::string> names;
+ std::vector<DLTensor*> arrays;
+ for (const auto& v : params_) {
+ names.emplace_back(v.first);
+ arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->()));
+ }
+ uint64_t sz = arrays.size();
+ CHECK(sz == names.size());
+ stream->Write(sz);
+ stream->Write(names);
+ for (size_t i = 0; i < sz; ++i) {
+ tvm::runtime::SaveDLTensor(stream, arrays[i]);
+ }
+ stream->Write(module_name_);
+}
+
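+// Instantiates a GraphRuntime from the stored graph JSON and the imported
+// library (imports_[0]), then binds the packaged params.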
+Module GraphRuntimeFactory::RuntimeCreate(const std::vector<TVMContext>& ctxs) {
+ auto exec = make_object<GraphRuntime>();
+ exec->Init(this->graph_json_, this->imports_[0], ctxs);
+ // set params
+ SetParams(exec.get(), this->params_);
+ return Module(exec);
+}
+
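+// Same as RuntimeCreate, but goes through the registered
+// tvm.graph_runtime_debug.create packed function.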
+Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector<TVMContext>& ctxs) {
+ const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_debug.create");
+  CHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_debug.create in registry. "
+                          "Did you enable the debug graph runtime build?";
+  // The debug runtime's create packed function calls GetAllContexts, so unpack the
+  // contexts into interleaved (device_type, device_id) pairs.
+ std::vector<int> unpacked_ctxs;
+ for (const auto& ctx : ctxs) {
+ unpacked_ctxs.emplace_back(ctx.device_type);
+ unpacked_ctxs.emplace_back(ctx.device_id);
+ }
+ size_t args_size = unpacked_ctxs.size() + 2;
+ std::vector<TVMValue> values(args_size);
+ std::vector<int> codes(args_size);
+ runtime::TVMArgsSetter setter(values.data(), codes.data());
+ setter(0, this->graph_json_);
+ setter(1, this->imports_[0]);
+ for (size_t i = 0; i < unpacked_ctxs.size(); ++i) {
+ setter(i + 2, unpacked_ctxs[i]);
+ }
+ TVMRetValue rv;
+ pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv);
+ Module mod = rv.operator Module();
+  // The debug graph runtime is a subclass of GraphRuntime, so the cast is safe.
+ SetParams(const_cast<GraphRuntime*>(mod.as<GraphRuntime>()), this->params_);
+ return mod;
+}
+
+Module GraphRuntimeFactoryModuleLoadBinary(void* strm) {
+ dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+ std::string graph_json;
+ std::unordered_map<std::string, tvm::runtime::NDArray> params;
+ std::string module_name;
+ CHECK(stream->Read(&graph_json));
+ uint64_t sz;
+ CHECK(stream->Read(&sz));
+ std::vector<std::string> names;
+ CHECK(stream->Read(&names));
+ CHECK(sz == names.size());
+ for (size_t i = 0; i < sz; ++i) {
+ tvm::runtime::NDArray temp;
+ temp.Load(stream);
+ params[names[i]] = temp;
+ }
+ CHECK(stream->Read(&module_name));
+ auto exec = make_object<GraphRuntimeFactory>(graph_json, params, module_name);
+ return Module(exec);
+}
+
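+// Entry point used on the Python side to package a compiled graph. The argument
+// order is (graph_json, lib, module_name, param_name_0, param_0, param_name_1, ...).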
+TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) {
+  CHECK_GE(args.num_args, 3) << "The expected number of arguments for "
+                                "graph_runtime_factory.create is at least 3, "
+                                "but it has "
+                             << args.num_args;
+ // The argument order is graph_json, module, module_name, params.
+ CHECK_EQ((args.size() - 3) % 2, 0);
+ std::unordered_map<std::string, tvm::runtime::NDArray> params;
+ for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {
+ std::string name = args[i].operator String();
+ params[name] = args[i + 1].operator tvm::runtime::NDArray();
+ }
+ auto exec = make_object<GraphRuntimeFactory>(args[0], params, args[2]);
+ exec->Import(args[1]);
+ *rv = Module(exec);
+});
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory")
+ .set_body_typed(GraphRuntimeFactoryModuleLoadBinary);
+
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay import testing
+from tvm.contrib import graph_runtime
+from tvm.contrib.debugger import debug_runtime
+
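+# Reference path: run the workload through the legacy (graph, lib, params) flow
+# and return its output so the factory-based paths below can be checked against it.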
+def verify(data):
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+ graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params)
+
+ ctx = tvm.cpu()
+ module = graph_runtime.create(graph, lib, ctx)
+ module.set_input("data", data)
+ module.set_input(**graph_params)
+ module.run()
+ out = module.get_output(0).asnumpy()
+
+ return out
+
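+# Building with the legacy three-value unpacking and graph_runtime.create
+# should keep working alongside the factory module.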
+def test_legacy_compatibility():
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+ graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = tvm.cpu()
+ module = graph_runtime.create(graph, lib, ctx)
+ module.set_input("data", data)
+ module.set_input(**graph_params)
+ module.run()
+ out = module.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
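+# relay.build now returns a single factory module; calling
+# compiled_graph_lib['default'](ctx) creates a graph runtime for the packaged model.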
+def test_cpu():
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+        compiled_graph_lib = relay.build_module.build(mod, "llvm", params=params)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ # raw api
+ ctx = tvm.cpu()
+    gmod = compiled_graph_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ set_input("data", tvm.nd.array(data))
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+    gmod = graph_runtime.GraphModule(compiled_graph_lib['default'](ctx))
+ gmod.set_input("data", data)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
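+# Same flow as test_cpu, but compiled for and run on CUDA.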
+def test_gpu():
+ if not tvm.runtime.enabled("cuda"):
+ print("Skip because cuda is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+        compiled_graph_lib = relay.build_module.build(mod, "cuda", params=params)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = tvm.gpu()
+
+ # raw api
+    gmod = compiled_graph_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ set_input("data", tvm.nd.array(data))
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+    gmod = graph_runtime.GraphModule(compiled_graph_lib['default'](ctx))
+ gmod.set_input("data", data)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
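+# Round-trip checks: export the factory module as a .so or .tar, reload it
+# (locally or over RPC), and verify the packaged graph and params still run.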
+def test_mod_export():
+ def verify_cpu_export(obj_format):
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "llvm", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib.export_library(path_lib)
+ loaded_lib = tvm.runtime.load_module(path_lib)
+ ctx = tvm.cpu(0)
+ gmod = loaded_lib['default'](ctx)
+
+ # raw api
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ set_input("data", tvm.nd.array(data))
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ gmod.set_input("data", data)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ def verify_gpu_export(obj_format):
+ if not tvm.runtime.enabled("cuda"):
+ print("Skip because cuda is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "cuda", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib.export_library(path_lib)
+ loaded_lib = tvm.runtime.load_module(path_lib)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = tvm.gpu()
+
+ # raw api
+ gmod = loaded_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ set_input("data", tvm.nd.array(data))
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ gmod.set_input("data", data)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ def verify_rpc_cpu_export(obj_format):
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "llvm", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib.export_library(path_lib)
+
+ from tvm import rpc
+ server = rpc.Server("localhost", use_popen=True)
+ remote = rpc.connect(server.host, server.port)
+ remote.upload(path_lib)
+ loaded_lib = remote.load_module(path_lib)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = remote.cpu()
+
+ # raw api
+ gmod = loaded_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ set_input("data", tvm.nd.array(data, ctx=ctx))
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ gmod.set_input("data", data)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ def verify_rpc_gpu_export(obj_format):
+ if not tvm.runtime.enabled("cuda"):
+ print("Skip because cuda is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "cuda", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib.export_library(path_lib)
+
+ from tvm import rpc
+ server = rpc.Server("localhost", use_popen=True)
+ remote = rpc.connect(server.host, server.port)
+ remote.upload(path_lib)
+ loaded_lib = remote.load_module(path_lib)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = remote.gpu()
+
+ # raw api
+ gmod = loaded_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ set_input("data", tvm.nd.array(data, ctx=ctx))
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ gmod.set_input("data", data)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ for obj_format in [".so", ".tar"]:
+ verify_cpu_export(obj_format)
+ verify_gpu_export(obj_format)
+ verify_rpc_cpu_export(obj_format)
+ verify_rpc_gpu_export(obj_format)
+
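+# remove_params returns a factory module without the packaged params; the params
+# are serialized separately and restored through load_params at run time.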
+def test_remove_package_params():
+ def verify_cpu_remove_package_params(obj_format):
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "llvm", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib_no_params = compiled_graph_lib["remove_params"]()
+        compiled_graph_lib_no_params.export_library(path_lib)
+ with open(temp.relpath("deploy_param.params"), "wb") as fo:
+            fo.write(relay.save_param_dict(compiled_graph_lib.get_params()))
+ loaded_lib = tvm.runtime.load_module(path_lib)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = tvm.cpu(0)
+
+ # raw api
+ gmod = loaded_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ load_params = gmod["load_params"]
+ loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
+ set_input("data", tvm.nd.array(data))
+ load_params(loaded_params)
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
+ gmod.set_input("data", data)
+ gmod.load_params(loaded_params)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ def verify_gpu_remove_package_params(obj_format):
+ if not tvm.runtime.enabled("cuda"):
+ print("Skip because cuda is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "cuda", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib_no_params = compiled_graph_lib["remove_params"]()
+        compiled_graph_lib_no_params.export_library(path_lib)
+ with open(temp.relpath("deploy_param.params"), "wb") as fo:
+            fo.write(relay.save_param_dict(compiled_graph_lib.get_params()))
+ loaded_lib = tvm.runtime.load_module(path_lib)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = tvm.gpu(0)
+
+ # raw api
+ gmod = loaded_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ load_params = gmod["load_params"]
+ loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
+ set_input("data", tvm.nd.array(data))
+ load_params(loaded_params)
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
+ gmod.set_input("data", data)
+ gmod.load_params(loaded_params)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ def verify_rpc_cpu_remove_package_params(obj_format):
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "llvm", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib_no_params = compiled_graph_lib["remove_params"]()
+        compiled_graph_lib_no_params.export_library(path_lib)
+ path_params = temp.relpath("deploy_param.params")
+ with open(path_params, "wb") as fo:
+            fo.write(relay.save_param_dict(compiled_graph_lib.get_params()))
+
+ from tvm import rpc
+ server = rpc.Server("localhost", use_popen=True)
+ remote = rpc.connect(server.host, server.port)
+ remote.upload(path_lib)
+ loaded_lib = remote.load_module(path_lib)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = remote.cpu()
+
+ # raw api
+ gmod = loaded_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ load_params = gmod["load_params"]
+ loaded_params = bytearray(open(path_params, "rb").read())
+ set_input("data", tvm.nd.array(data, ctx=ctx))
+ load_params(loaded_params)
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ loaded_params = bytearray(open(path_params, "rb").read())
+ gmod.set_input("data", data)
+ gmod.load_params(loaded_params)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ def verify_rpc_gpu_remove_package_params(obj_format):
+ if not tvm.runtime.enabled("cuda"):
+ print("Skip because cuda is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+            compiled_graph_lib = relay.build_module.build(mod, "cuda", params=params)
+
+ from tvm.contrib import util
+ temp = util.tempdir()
+ if obj_format == ".so":
+ file_name = "deploy_lib.so"
+ else:
+ assert obj_format == ".tar"
+ file_name = "deploy_lib.tar"
+ path_lib = temp.relpath(file_name)
+        compiled_graph_lib_no_params = compiled_graph_lib["remove_params"]()
+        compiled_graph_lib_no_params.export_library(path_lib)
+ path_params = temp.relpath("deploy_param.params")
+ with open(path_params, "wb") as fo:
+            fo.write(relay.save_param_dict(compiled_graph_lib.get_params()))
+
+ from tvm import rpc
+ server = rpc.Server("localhost", use_popen=True)
+ remote = rpc.connect(server.host, server.port)
+ remote.upload(path_lib)
+ loaded_lib = remote.load_module(path_lib)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+ ctx = remote.gpu()
+
+ # raw api
+ gmod = loaded_lib['default'](ctx)
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ load_params = gmod["load_params"]
+ loaded_params = bytearray(open(path_params, "rb").read())
+ set_input("data", tvm.nd.array(data, ctx=ctx))
+ load_params(loaded_params)
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # graph runtime wrapper
+ gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+ loaded_params = bytearray(open(path_params, "rb").read())
+ gmod.set_input("data", data)
+ gmod.load_params(loaded_params)
+ gmod.run()
+ out = gmod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ for obj_format in [".so", ".tar"]:
+ verify_cpu_remove_package_params(obj_format)
+ verify_gpu_remove_package_params(obj_format)
+ verify_rpc_cpu_remove_package_params(obj_format)
+ verify_rpc_gpu_remove_package_params(obj_format)
+
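+# debug_create builds a debug graph runtime from the factory; it takes the model
+# name followed by the contexts (see GraphRuntimeFactory::GetFunction).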
+def test_debug_graph_runtime():
+ if not tvm.runtime.enabled("llvm"):
+ print("Skip because llvm is not enabled")
+ return
+ mod, params = relay.testing.resnet.get_workload(num_layers=18)
+ with relay.build_config(opt_level=3):
+        compiled_graph_lib = relay.build_module.build(mod, "llvm", params=params)
+ data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
+
+ # raw api
+ ctx = tvm.cpu()
+    try:
+        gmod = compiled_graph_lib['debug_create']('default', ctx)
+    except Exception:
+        print("Skip because debug graph runtime is not enabled")
+        return
+ set_input = gmod["set_input"]
+ run = gmod["run"]
+ get_output = gmod["get_output"]
+ set_input("data", tvm.nd.array(data))
+ run()
+ out = get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+ # debug graph runtime wrapper
+    debug_g_mod = debug_runtime.GraphModuleDebug(
+        compiled_graph_lib['debug_create']('default', ctx), [ctx],
+        compiled_graph_lib.get_json(), None)
+ debug_g_mod.set_input("data", data)
+ debug_g_mod.run()
+ out = debug_g_mod.get_output(0).asnumpy()
+ tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
+
+if __name__ == "__main__":
+ test_legacy_compatibility()
+ test_cpu()
+ test_gpu()
+ test_mod_export()
+ test_remove_package_params()
+ test_debug_graph_runtime()