[Runtime] Allow for parameter sharing in GraphRuntime (#3384)
authorAndrew Tulloch <andrew@tullo.ch>
Tue, 25 Jun 2019 04:06:20 +0000 (21:06 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 25 Jun 2019 04:06:20 +0000 (21:06 -0700)
Summary:

In multi-threaded applications where we have multiple inferences on the
same model in parallel (consider e.g. a TTS system handling multiple
requests), it can be useful to share the parameters of a model amongst
these multiple instances. This improves the cache utilization behaviour
of the system, as multiple cores can use the same set of weights instead
of evicting the identical copies of weights in a shared cache.

As the underlying `NDArray` instances in `data_entry_` implement a
ref-counted based sharing system, this is a simple modification of the
`GraphRuntime::LoadParams` logic to instead copy parameters from an
existing GraphRuntime instance. This is a little ugly in that we need
both the pre-existing GraphRuntime instance, as well as the 'serialized'
params (since we need to know the set of names we should copy), but
without imposing additional assumptions (i.e. storing the set of param
names in GraphRuntime, and enforcing that shared param names are
identical to the parameters set in the preceding `LoadParams` call),
this seems unavoidable.

Test Plan:

Unit test added.

python/tvm/contrib/graph_runtime.py
src/runtime/graph/graph_runtime.cc
src/runtime/graph/graph_runtime.h
tests/python/unittest/test_runtime_graph.py

index 4d0698a..0c9ce40 100644 (file)
@@ -129,6 +129,7 @@ class GraphModule(object):
         self._get_input = module["get_input"]
         self._get_num_outputs = module["get_num_outputs"]
         self._load_params = module["load_params"]
+        self._share_params = module["share_params"]
 
     def set_input(self, key=None, value=None, **params):
         """Set inputs to the module via kwargs
@@ -234,6 +235,19 @@ class GraphModule(object):
         """
         self._load_params(bytearray(params_bytes))
 
+    def share_params(self, other, params_bytes):
+        """Share parameters from pre-existing GraphRuntime instance.
+
+        Parameters
+        ----------
+        other: GraphRuntime
+            The parent GraphRuntime from which this instance should share
+            it's parameters.
+        params_bytes : bytearray
+            The serialized parameter dict (used only for the parameter names).
+        """
+        self._share_params(other.module, bytearray(params_bytes))
+
     def __getitem__(self, key):
         """Get internal module function
 
index 960d509..cc37a85 100644 (file)
@@ -184,6 +184,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
   }
 }
 
+  void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
+    uint64_t header, reserved;
+    CHECK(strm->Read(&header))
+      << "Invalid parameters file format";
+    CHECK(header == kTVMNDArrayListMagic)
+      << "Invalid parameters file format";
+    CHECK(strm->Read(&reserved))
+      << "Invalid parameters file format";
+  std::vector<std::string> names;
+  CHECK(strm->Read(&names)) << "Invalid parameters file format";
+  uint64_t sz;
+  strm->Read(&sz);
+  size_t size = static_cast<size_t>(sz);
+  CHECK(size == names.size()) << "Invalid parameters file format";
+  for (size_t i = 0; i < size; ++i) {
+    int in_idx = GetInputIndex(names[i]);
+    CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
+    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
+    CHECK_LT(eid, data_entry_.size());
+    CHECK_EQ(data_entry_[eid].use_count(), 1);
+    data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
+    CHECK_GT(data_entry_[eid].use_count(), 1);
+  }
+  this->SetupOpExecs();
+}
+
 void GraphRuntime::SetupStorage() {
   // Grab saved optimization plan from graph.
   std::vector<TVMType> vtype;
@@ -372,6 +398,14 @@ PackedFunc GraphRuntime::GetFunction(
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         this->LoadParams(args[0].operator std::string());
       });
+  } else if (name == "share_params") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        const auto& module = args[0].operator Module();
+        CHECK_EQ(module.operator->()->type_key(), "GraphRuntime");
+        const auto& param_blob = args[1].operator std::string();
+        dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
+        this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm);
+      });
   } else {
     return PackedFunc();
   }
index 5298f22..e3f5815 100644 (file)
@@ -147,10 +147,19 @@ class GraphRuntime : public ModuleNode {
    * \param param_blob A binary blob of parameter.
    */
   void LoadParams(const std::string& param_blob);
- /*!
-  * \brief Get total number of nodes.
-  * \return Total number of nodes.
-  */
+
+  /*!
+   * \brief Share parameters from pre-existing GraphRuntime instance.
+   * \param other A GraphRuntime instance, previously with |LoadParams| called with the
+   * identical input |param_blob|.
+   * \param strm The input stream.
+   */
+  void ShareParams(const GraphRuntime& other, dmlc::Stream* strm);
+
+  /*!
+   * \brief Get total number of nodes.
+   * \return Total number of nodes.
+   */
   uint32_t GetNumOfNodes() const {
     return static_cast<uint32_t>(nodes_.size());
   }
index 20af8a0..f331f5b 100644 (file)
@@ -81,8 +81,46 @@ def test_graph_simple():
         out = mod.get_output(0, out)
         np.testing.assert_equal(out.asnumpy(), a + 1)
 
+    def check_sharing():
+        from tvm import relay
+        x = relay.var('x', shape=(1, 10))
+        y = relay.var('y', shape=(1, 10))
+        z = relay.add(x, y)
+        func = relay.Function([x, y], z)
+
+        x_in = np.ones((1, 10)).astype("float32")
+        params = {'x': x_in}
+        graph, lib, params = relay.build(func, target="llvm", params=params)
+
+        if not tvm.module.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
+        mod_shared.load_params(relay.save_param_dict(params))
+        num_mods = 10
+        mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
+                for _ in range(num_mods)]
+
+        for mod in mods:
+            mod.share_params(mod_shared, relay.save_param_dict(params))
+
+        a = np.random.uniform(size=(1, 10)).astype("float32")
+        for mod in mods:
+            mod.run(y=a)
+            out = mod.get_output(0, tvm.nd.empty((1, 10)))
+            np.testing.assert_equal(out.asnumpy(), x_in + a)
+
+        # Explicitly delete the shared module and verify correctness.
+        del mod_shared
+        for mod in mods:
+            mod.run(y=a)
+            out = mod.get_output(0, tvm.nd.empty((1, 10)))
+            np.testing.assert_equal(out.asnumpy(), x_in + a)
+            del mod
+
     check_verify()
     check_remote()
+    check_sharing()
 
 if __name__ == "__main__":
     test_graph_simple()