From 9fcde21313fd947d379c3c96d114f080676c9308 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Wed, 15 Jul 2020 11:07:43 +0800 Subject: [PATCH] [RUNTIME] Support module based interface runtime (#5753) --- python/tvm/contrib/debugger/debug_runtime.py | 13 +- python/tvm/contrib/graph_runtime.py | 12 +- python/tvm/relay/backend/graph_runtime_factory.py | 84 ++++ python/tvm/relay/build_module.py | 9 +- src/runtime/graph/graph_runtime_factory.cc | 175 +++++++ src/runtime/graph/graph_runtime_factory.h | 131 ++++++ src/runtime/module.cc | 3 +- .../test_runtime_module_based_interface.py | 520 +++++++++++++++++++++ 8 files changed, 927 insertions(+), 20 deletions(-) create mode 100644 python/tvm/relay/backend/graph_runtime_factory.py create mode 100644 src/runtime/graph/graph_runtime_factory.cc create mode 100644 src/runtime/graph/graph_runtime_factory.h create mode 100644 tests/python/unittest/test_runtime_module_based_interface.py diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 848d7f5..1f96a86 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -35,10 +35,10 @@ def create(graph_json_str, libmod, ctx, dump_root=None): Parameters ---------- - graph_json_str : str or graph class + graph_json_str : str The graph to be deployed in json format output by graph compiler. - The graph can only contain one operator(tvm_op) that - points to the name of PackedFunc in the libmod. + The graph can contain operator(tvm_op) that points to the name + of PackedFunc in the libmod. libmod : tvm.Module The module of the corresponding function. @@ -54,11 +54,8 @@ def create(graph_json_str, libmod, ctx, dump_root=None): graph_module : GraphModuleDebug Debug Runtime graph module that can be used to execute the graph. """ - if not isinstance(graph_json_str, string_types): - try: - graph_json_str = graph_json_str._tvm_graph_json() - except AttributeError: - raise ValueError("Type %s is not supported" % type(graph_json_str)) + assert isinstance(graph_json_str, string_types) + try: ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) if num_rpc_ctx == len(ctx): diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 9b714a8..ec102f5 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -29,10 +29,10 @@ def create(graph_json_str, libmod, ctx): Parameters ---------- - graph_json_str : str or graph class + graph_json_str : str The graph to be deployed in json format output by json graph. - The graph can only contain one operator(tvm_op) that - points to the name of PackedFunc in the libmod. + The graph can contain operator(tvm_op) that points to the name + of PackedFunc in the libmod. libmod : tvm.runtime.Module The module of the corresponding function @@ -48,11 +48,7 @@ def create(graph_json_str, libmod, ctx): graph_module : GraphModule Runtime graph module that can be used to execute the graph. """ - if not isinstance(graph_json_str, string_types): - try: - graph_json_str = graph_json_str._tvm_graph_json() - except AttributeError: - raise ValueError("Type %s is not supported" % type(graph_json_str)) + assert isinstance(graph_json_str, string_types) ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx) diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py new file mode 100644 index 0000000..f7ed122 --- /dev/null +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Graph runtime factory.""" +import warnings +from tvm._ffi.base import string_types +from tvm._ffi.registry import get_global_func +from tvm.runtime import ndarray + +class GraphRuntimeFactoryModule(object): + """Graph runtime factory module. + This is a module of graph runtime factory + + Parameters + ---------- + graph_json_str : str + The graph to be deployed in json format output by graph compiler. + The graph can contain operator(tvm_op) that points to the name of + PackedFunc in the libmod. + libmod : tvm.Module + The module of the corresponding function + libmod_name: str + The name of module + params : dict of str to NDArray + The parameters of module + """ + + def __init__(self, graph_json_str, libmod, libmod_name, params): + assert isinstance(graph_json_str, string_types) + fcreate = get_global_func("tvm.graph_runtime_factory.create") + args = [] + for k, v in params.items(): + args.append(k) + args.append(ndarray.array(v)) + self.module = fcreate(graph_json_str, libmod, libmod_name, *args) + self.graph_json = graph_json_str + self.lib = libmod + self.libmod_name = libmod_name + self.params = params + self.iter_cnt = 0 + + def export_library(self, file_name, fcompile=None, addons=None, **kwargs): + return self.module.export_library(file_name, fcompile, addons, **kwargs) + + # Sometimes we want to get params explicitly. + # For example, we want to save its params value to + # an independent file. + def get_params(self): + return self.params + + def get_json(self): + return self.graph_json + + def __getitem__(self, item): + return self.module.__getitem__(item) + + def __iter__(self): + warnings.warn( + "legacy graph runtime behaviour of producing json / lib / params will be " + "removed in the next release ", + DeprecationWarning, 2) + return self + + def __next__(self): + if self.iter_cnt > 2: + raise StopIteration + + objs = [self.graph_json, self.lib, self.params] + obj = objs[self.iter_cnt] + self.iter_cnt += 1 + return obj diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index a28ab85..896f334 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -30,6 +30,7 @@ from . import _build_module from . import ty as _ty from . import expr as _expr from . import function as _function +from .backend import graph_runtime_factory as _graph_runtime_factory from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -181,7 +182,7 @@ class BuildModule(object): return ret -def build(mod, target=None, target_host=None, params=None): +def build(mod, target=None, target_host=None, params=None, mod_name='default'): """Helper function that builds a Relay function to run on TVM graph runtime. @@ -208,6 +209,9 @@ def build(mod, target=None, target_host=None, params=None): Input parameters to the graph that do not change during inference time. Used for constant folding. + mod_name: Optional[str] + The module name we will build + Returns ------- graph_json : str @@ -249,7 +253,8 @@ def build(mod, target=None, target_host=None, params=None): with tophub_context: bld_mod = BuildModule() graph_json, mod, params = bld_mod.build(mod, target, target_host, params) - return graph_json, mod, params + mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params) + return mod def optimize(mod, target=None, params=None): diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc new file mode 100644 index 0000000..aa35afa --- /dev/null +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_runtime_factory.cc + * \brief Graph runtime factory implementations + */ + +#include "./graph_runtime_factory.h" + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { + +GraphRuntimeFactory::GraphRuntimeFactory( + const std::string& graph_json, + const std::unordered_map& params, + const std::string& module_name) { + graph_json_ = graph_json; + params_ = params; + module_name_ = module_name; +} + +PackedFunc GraphRuntimeFactory::GetFunction( + const std::string& name, const tvm::runtime::ObjectPtr& sptr_to_self) { + if (name == module_name_) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::vector contexts; + for (int i = 0; i < args.num_args; ++i) { + contexts.emplace_back(args[i].operator TVMContext()); + } + *rv = this->RuntimeCreate(contexts); + }); + } else if (name == "debug_create") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.size(), 2); + std::string module_name = args[0].operator String(); + CHECK(module_name == module_name_) << "Currently we only support single model for now."; + std::vector contexts; + for (int i = 1; i < args.num_args; ++i) { + contexts.emplace_back(args[i].operator TVMContext()); + } + *rv = this->DebugRuntimeCreate(contexts); + }); + } else if (name == "remove_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::unordered_map empty_params{}; + auto exec = + make_object(this->graph_json_, empty_params, this->module_name_); + exec->Import(this->imports_[0]); + *rv = Module(exec); + }); + } else { + return PackedFunc(); + } +} + +void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { + stream->Write(graph_json_); + std::vector names; + std::vector arrays; + for (const auto& v : params_) { + names.emplace_back(v.first); + arrays.emplace_back(const_cast(v.second.operator->())); + } + uint64_t sz = arrays.size(); + CHECK(sz == names.size()); + stream->Write(sz); + stream->Write(names); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(stream, arrays[i]); + } + stream->Write(module_name_); +} + +Module GraphRuntimeFactory::RuntimeCreate(const std::vector& ctxs) { + auto exec = make_object(); + exec->Init(this->graph_json_, this->imports_[0], ctxs); + // set params + SetParams(exec.get(), this->params_); + return Module(exec); +} + +Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector& ctxs) { + const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_debug.create"); + CHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_debug.create in registry. " + "Do you enable debug graph runtime build?"; + // Debug runtime create packed function will call GetAllContexs, so we unpack the ctxs. + std::vector unpacked_ctxs; + for (const auto& ctx : ctxs) { + unpacked_ctxs.emplace_back(ctx.device_type); + unpacked_ctxs.emplace_back(ctx.device_id); + } + size_t args_size = unpacked_ctxs.size() + 2; + std::vector values(args_size); + std::vector codes(args_size); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, this->graph_json_); + setter(1, this->imports_[0]); + for (size_t i = 0; i < unpacked_ctxs.size(); ++i) { + setter(i + 2, unpacked_ctxs[i]); + } + TVMRetValue rv; + pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); + Module mod = rv.operator Module(); + // debug graph runtime is one child class of graph runtime. + SetParams(const_cast(mod.as()), this->params_); + return mod; +} + +Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::string graph_json; + std::unordered_map params; + std::string module_name; + CHECK(stream->Read(&graph_json)); + uint64_t sz; + CHECK(stream->Read(&sz)); + std::vector names; + CHECK(stream->Read(&names)); + CHECK(sz == names.size()); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::NDArray temp; + temp.Load(stream); + params[names[i]] = temp; + } + CHECK(stream->Read(&module_name)); + auto exec = make_object(graph_json, params, module_name); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 3) << "The expected number of arguments for " + "graph_runtime_factory.create needs at least 3, " + "but it has " + << args.num_args; + // The argument order is graph_json, module, module_name, params. + CHECK_EQ((args.size() - 3) % 2, 0); + std::unordered_map params; + for (size_t i = 3; i < static_cast(args.size()); i += 2) { + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); + } + auto exec = make_object(args[0], params, args[2]); + exec->Import(args[1]); + *rv = Module(exec); +}); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory") + .set_body_typed(GraphRuntimeFactoryModuleLoadBinary); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h new file mode 100644 index 0000000..98fb27c --- /dev/null +++ b/src/runtime/graph/graph_runtime_factory.h @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/graph_runtime_factory.h + * \brief Graph runtime factory creating graph runtime. + */ + +#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ +#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "./graph_runtime.h" + +namespace tvm { +namespace runtime { + +class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { + public: + /*! + * \brief Construct the GraphRuntimeFactory. + * \param graph_json The execution graph. + * \param params The params of graph. + * \param module_name The module name of graph. + */ + GraphRuntimeFactory(const std::string& graph_json, + const std::unordered_map& params, + const std::string& module_name = "default"); + + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const override { return "GraphRuntimeFactory"; } + + /*! + * \brief Save the module to binary stream. + * \param stream The binary stream to save to. + */ + void SaveToBinary(dmlc::Stream* stream) override; + + /*! + * \brief Create a specific runtime module + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. + * \return created runtime module + */ + Module RuntimeCreate(const std::vector& ctxs); + + /*! + * \brief Create a specific debug runtime module + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. + * \return created debug runtime module + */ + Module DebugRuntimeCreate(const std::vector& ctxs); + + /*! + * \brief Set params. + * \param graph_runtime The graph runtime we want to set the params into. + * \param params The graph params value we want to set. + */ + void SetParams(GraphRuntime* graph_runtime, + const std::unordered_map& params) const { + std::unordered_map value = params; + // upload big arrays first to avoid memory issue in rpc mode + std::vector keys; + for (const auto& p : value) { + keys.emplace_back(p.first); + } + std::sort(std::begin(keys), std::end(keys), + [&](const std::string& lhs, const std::string& rhs) -> bool { + auto lhs_size = GetDataSize(value[lhs].ToDLPack()->dl_tensor); + auto rhs_size = GetDataSize(value[rhs].ToDLPack()->dl_tensor); + return lhs_size > rhs_size; + }); + for (const auto& key : keys) { + int in_idx = graph_runtime->GetInputIndex(key); + if (in_idx >= 0) { + graph_runtime->SetInput(in_idx, const_cast(value[key].operator->())); + } + } + } + + protected: + /*! \brief The execution graph. */ + std::string graph_json_; + /*! \brief The params. */ + std::unordered_map params_; + /*! \brief module name */ + std::string module_name_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 46ef6fa..8052467 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -67,8 +67,7 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) if (pf != nullptr) return pf; if (query_imports) { for (Module& m : self->imports_) { - pf = m->GetFunction(name, m.data_); - if (pf != nullptr) return pf; + pf = m.operator->()->GetFunction(name, query_imports); } } return pf; diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py new file mode 100644 index 0000000..5ab4e82 --- /dev/null +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -0,0 +1,520 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +from tvm import relay +from tvm.relay import testing +import tvm +from tvm.contrib import graph_runtime +from tvm.contrib.debugger import debug_runtime + +def verify(data): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) + + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.set_input("data", data) + module.set_input(**graph_params) + module.run() + out = module.get_output(0).asnumpy() + + return out + +def test_legacy_compatibility(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.set_input("data", data) + module.set_input(**graph_params) + module.run() + out = module.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_cpu(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + # raw api + ctx = tvm.cpu() + gmod = complied_graph_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_gpu(): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + + # raw api + gmod = complied_graph_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_mod_export(): + def verify_cpu_export(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + loaded_lib = tvm.runtime.load_module(path_lib) + ctx = tvm.cpu(0) + gmod = loaded_lib['default'](ctx) + + # raw api + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_gpu_export(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + loaded_lib = tvm.runtime.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_cpu_export(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.cpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data, ctx=ctx)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_gpu_export(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.gpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data, ctx=ctx)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + for obj_format in [".so", ".tar"]: + verify_cpu_export(obj_format) + verify_gpu_export(obj_format) + verify_rpc_cpu_export(obj_format) + verify_rpc_gpu_export(obj_format) + +def test_remove_package_params(): + def verify_cpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + loaded_lib = tvm.runtime.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.cpu(0) + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + set_input("data", tvm.nd.array(data)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_gpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + loaded_lib = tvm.runtime.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu(0) + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + set_input("data", tvm.nd.array(data)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_cpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + path_params = temp.relpath("deploy_param.params") + with open(path_params, "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.cpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(path_params, "rb").read()) + set_input("data", tvm.nd.array(data, ctx=ctx)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(path_params, "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_gpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + path_params = temp.relpath("deploy_param.params") + with open(path_params, "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.gpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(path_params, "rb").read()) + set_input("data", tvm.nd.array(data, ctx=ctx)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(path_params, "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + for obj_format in [".so", ".tar"]: + verify_cpu_remove_package_params(obj_format) + verify_gpu_remove_package_params(obj_format) + verify_rpc_cpu_remove_package_params(obj_format) + verify_rpc_gpu_remove_package_params(obj_format) + +def test_debug_graph_runtime(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + + # raw api + ctx = tvm.cpu() + try: + gmod = complied_graph_lib['debug_create']('default', ctx) + except: + print("Skip because debug graph_runtime not enabled") + return + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # debug graph runtime wrapper + debug_g_mod = debug_runtime.GraphModuleDebug(complied_graph_lib['debug_create']('default', ctx), [ctx], + complied_graph_lib.get_json(), None) + debug_g_mod.set_input("data", data) + debug_g_mod.run() + out = debug_g_mod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +if __name__ == "__main__": + test_legacy_compatibility() + test_cpu() + test_gpu() + test_mod_export() + test_remove_package_params() + test_debug_graph_runtime() -- 2.7.4