[RPC][BUGFIX][BACKPORT-0.6] Fix bug in rpc ring buffer shrink (#5516)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 5 May 2020 23:47:00 +0000 (16:47 -0700)
committerGitHub <noreply@github.com>
Tue, 5 May 2020 23:47:00 +0000 (16:47 -0700)
src/support/ring_buffer.h
tests/python/unittest/test_runtime_rpc.py

index e6e3b04..7a1bcb6 100644 (file)
@@ -49,8 +49,12 @@ class RingBuffer {
     return ring_.size();
   }
   /*!
-   * Reserve capacity to be at least n.
-   * Will only increase capacity if n is bigger than current capacity.
+   *  Reserve capacity to be at least n.
+   *  Will only increase capacity if n is bigger than current capacity.
+   *
+   *  The effect of Reserve only lasts before the next call to Reserve.
+   *  Other functions in the ring buffer can also call into the reserve.
+   *
    * \param n The size of capacity.
    */
   void Reserve(size_t n) {
@@ -63,19 +67,27 @@ class RingBuffer {
           size_t ncopy = head_ptr_ + bytes_available_ - old_size;
           memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
         }
-    } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) {
-        // shrink too large temporary buffer to avoid out of memory on some embedded devices
+    } else if (ring_.size() > n * 8 &&
+               ring_.size() > kInitCapacity) {
+      // shrink too large temporary buffer to
+      // avoid out of memory on some embedded devices
+      if (bytes_available_ != 0) {
+        // move existing bytes to the head.
         size_t old_bytes = bytes_available_;
-
         std::vector<char> tmp(old_bytes);
-
         Read(&tmp[0], old_bytes);
-        ring_.resize(kInitCapacity);
-        ring_.shrink_to_fit();
 
         memcpy(&ring_[0], &tmp[0], old_bytes);
-        head_ptr_ = 0;
         bytes_available_ = old_bytes;
+      }
+      // shrink the ring.
+      size_t new_size  = kInitCapacity;
+      new_size = std::max(new_size, n);
+      new_size = std::max(new_size, bytes_available_);
+
+      ring_.resize(new_size);
+      ring_.shrink_to_fit();
+      head_ptr_ = 0;
     }
   }
 
index 091e942..4e7921b 100644 (file)
@@ -102,6 +102,19 @@ def test_rpc_array():
     fremote(r_cpu)
 
 
+def test_rpc_large_array():
+    # testcase of large array creation
+    server = rpc.Server("localhost")
+    remote = rpc.connect(server.host, server.port)
+    ctx = remote.cpu(0)
+    a_np = np.ones((5041, 720)).astype('float32')
+    b_np = np.ones((720, 192)).astype('float32')
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    np.testing.assert_equal(a.asnumpy(), a_np)
+    np.testing.assert_equal(b.asnumpy(), b_np)
+
+
 def test_rpc_echo():
     def check(remote):
         fecho = remote.get_function("testing.echo")
@@ -447,3 +460,4 @@ if __name__ == "__main__":
     test_local_func()
     test_rpc_tracker_register()
     test_rpc_tracker_request()
+    test_rpc_large_array()