From: Andrew Tulloch
Date: Tue, 25 Jun 2019 04:06:20 +0000 (-0700)
Subject: [Runtime] Allow for parameter sharing in GraphRuntime (#3384)
X-Git-Tag: upstream/0.7.0~2259
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=32be34a07f72e1ac008b385a07ca1cb66ec59e6e;p=platform%2Fupstream%2Ftvm.git

[Runtime] Allow for parameter sharing in GraphRuntime (#3384)

Summary:
In multi-threaded applications where we run multiple inferences on the
same model in parallel (consider e.g. a TTS system handling multiple
requests), it can be useful to share the parameters of a model amongst
these multiple instances. This improves the cache utilization behaviour
of the system, as multiple cores can use the same set of weights
instead of evicting identical copies of the weights from a shared
cache.

As the underlying `NDArray` instances in `data_entry_` implement a
reference-counted sharing mechanism, this is a simple modification of
the `GraphRuntime::LoadParams` logic to instead copy parameters from an
existing GraphRuntime instance. This is a little ugly in that we need
both the pre-existing GraphRuntime instance and the 'serialized' params
(since we need to know the set of names we should copy), but without
imposing additional assumptions (i.e. storing the set of param names in
GraphRuntime, and enforcing that shared param names are identical to
the parameters set in the preceding `LoadParams` call), this seems
unavoidable.

Test Plan: Unit test added.
---

diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py
index 4d0698a..0c9ce40 100644
--- a/python/tvm/contrib/graph_runtime.py
+++ b/python/tvm/contrib/graph_runtime.py
@@ -129,6 +129,7 @@ class GraphModule(object):
         self._get_input = module["get_input"]
         self._get_num_outputs = module["get_num_outputs"]
         self._load_params = module["load_params"]
+        self._share_params = module["share_params"]
 
     def set_input(self, key=None, value=None, **params):
         """Set inputs to the module via kwargs
@@ -234,6 +235,19 @@ class GraphModule(object):
         """
         self._load_params(bytearray(params_bytes))
 
+    def share_params(self, other, params_bytes):
+        """Share parameters from a pre-existing GraphRuntime instance.
+
+        Parameters
+        ----------
+        other: GraphRuntime
+            The parent GraphRuntime from which this instance should share
+            its parameters.
+        params_bytes : bytearray
+            The serialized parameter dict (used only for the parameter names).
+ """ + self._share_params(other.module, bytearray(params_bytes)) + def __getitem__(self, key): """Get internal module function diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 960d509..cc37a85 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -184,6 +184,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { } } + void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { + uint64_t header, reserved; + CHECK(strm->Read(&header)) + << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) + << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) + << "Invalid parameters file format"; + std::vector names; + CHECK(strm->Read(&names)) << "Invalid parameters file format"; + uint64_t sz; + strm->Read(&sz); + size_t size = static_cast(sz); + CHECK(size == names.size()) << "Invalid parameters file format"; + for (size_t i = 0; i < size; ++i) { + int in_idx = GetInputIndex(names[i]); + CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; + uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); + CHECK_LT(eid, data_entry_.size()); + CHECK_EQ(data_entry_[eid].use_count(), 1); + data_entry_[eid] = other.GetInput(GetInputIndex(names[i])); + CHECK_GT(data_entry_[eid].use_count(), 1); + } + this->SetupOpExecs(); +} + void GraphRuntime::SetupStorage() { // Grab saved optimization plan from graph. std::vector vtype; @@ -372,6 +398,14 @@ PackedFunc GraphRuntime::GetFunction( return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->LoadParams(args[0].operator std::string()); }); + } else if (name == "share_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + const auto& module = args[0].operator Module(); + CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); + const auto& param_blob = args[1].operator std::string(); + dmlc::MemoryStringStream strm(const_cast(¶m_blob)); + this->ShareParams(dynamic_cast(*module.operator->()), &strm); + }); } else { return PackedFunc(); } diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 5298f22..e3f5815 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -147,10 +147,19 @@ class GraphRuntime : public ModuleNode { * \param param_blob A binary blob of parameter. */ void LoadParams(const std::string& param_blob); - /*! - * \brief Get total number of nodes. - * \return Total number of nodes. - */ + + /*! + * \brief Share parameters from pre-existing GraphRuntime instance. + * \param other A GraphRuntime instance, previously with |LoadParams| called with the + * identical input |param_blob|. + * \param strm The input stream. + */ + void ShareParams(const GraphRuntime& other, dmlc::Stream* strm); + + /*! + * \brief Get total number of nodes. + * \return Total number of nodes. 
+   */
   uint32_t GetNumOfNodes() const {
     return static_cast<uint32_t>(nodes_.size());
   }
diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py
index 20af8a0..f331f5b 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -81,8 +81,46 @@ def test_graph_simple():
         out = mod.get_output(0, out)
         np.testing.assert_equal(out.asnumpy(), a + 1)
 
+    def check_sharing():
+        from tvm import relay
+        x = relay.var('x', shape=(1, 10))
+        y = relay.var('y', shape=(1, 10))
+        z = relay.add(x, y)
+        func = relay.Function([x, y], z)
+
+        x_in = np.ones((1, 10)).astype("float32")
+        params = {'x': x_in}
+        graph, lib, params = relay.build(func, target="llvm", params=params)
+
+        if not tvm.module.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
+        mod_shared.load_params(relay.save_param_dict(params))
+        num_mods = 10
+        mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
+                for _ in range(num_mods)]
+
+        for mod in mods:
+            mod.share_params(mod_shared, relay.save_param_dict(params))
+
+        a = np.random.uniform(size=(1, 10)).astype("float32")
+        for mod in mods:
+            mod.run(y=a)
+            out = mod.get_output(0, tvm.nd.empty((1, 10)))
+            np.testing.assert_equal(out.asnumpy(), x_in + a)
+
+        # Explicitly delete the shared module and verify correctness.
+        del mod_shared
+        for mod in mods:
+            mod.run(y=a)
+            out = mod.get_output(0, tvm.nd.empty((1, 10)))
+            np.testing.assert_equal(out.asnumpy(), x_in + a)
+        del mod
+
     check_verify()
     check_remote()
+    check_sharing()
 
 if __name__ == "__main__":
     test_graph_simple()
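
Usage sketch (not part of the patch above; `parent`, `worker`, and the small
`relay.add` model are illustrative names chosen to mirror the unit test): the
new `share_params` API is intended to be used roughly as follows.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    # Build a module with bound parameters (illustrative model).
    x = relay.var('x', shape=(1, 10))
    y = relay.var('y', shape=(1, 10))
    func = relay.Function([x, y], relay.add(x, y))
    params = {'x': np.ones((1, 10), dtype='float32')}
    graph, lib, params = relay.build(func, target="llvm", params=params)

    # The parent instance owns the parameter NDArrays.
    parent = graph_runtime.create(graph, lib, tvm.cpu(0))
    parent.load_params(relay.save_param_dict(params))

    # Additional instances reuse the parent's parameter storage instead of
    # holding their own copies; the serialized dict is read only for names.
    worker = graph_runtime.create(graph, lib, tvm.cpu(0))
    worker.share_params(parent, relay.save_param_dict(params))

Because the parameter `NDArray`s are reference counted, the shared storage
remains valid even if the parent module is deleted afterwards, as the unit
test above verifies.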