[RPC] Better handle tempdir if subprocess killed. (#3574)
authorBalint Cristian <cristian.balint@gmail.com>
Fri, 19 Jul 2019 16:22:46 +0000 (19:22 +0300)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 19 Jul 2019 16:22:46 +0000 (09:22 -0700)
python/tvm/contrib/util.py
python/tvm/module.py
python/tvm/rpc/server.py

index 8db84ca..2ab370b 100644 (file)
@@ -30,8 +30,12 @@ class TempDirectory(object):
 
     Automatically removes the directory when it went out of scope.
     """
-    def __init__(self):
-        self.temp_dir = tempfile.mkdtemp()
+    def __init__(self, custom_path=None):
+        if custom_path:
+            os.mkdir(custom_path)
+            self.temp_dir = custom_path
+        else:
+            self.temp_dir = tempfile.mkdtemp()
         self._rmtree = shutil.rmtree
 
     def remove(self):
@@ -69,15 +73,20 @@ class TempDirectory(object):
         return os.listdir(self.temp_dir)
 
 
-def tempdir():
+def tempdir(custom_path=None):
     """Create temp dir which deletes the contents when exit.
 
+    Parameters
+    ----------
+    custom_path : str, optional
+        Manually specify the exact temp dir path
+
     Returns
     -------
     temp : TempDirectory
         The temp directory object
     """
-    return TempDirectory()
+    return TempDirectory(custom_path)
 
 
 class FileLock(object):
index ae3a117..b6fc0f5 100644 (file)
@@ -251,7 +251,7 @@ def load(path, fmt=""):
         _cc.create_shared(path + ".so", path)
         path += ".so"
     elif path.endswith(".tar"):
-        tar_temp = _util.tempdir()
+        tar_temp = _util.tempdir(custom_path=path.replace('.tar', ''))
         _tar.untar(path, tar_temp.temp_dir)
         files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
         _cc.create_shared(path + ".so", files)
index 5fffbb2..22b87c6 100644 (file)
@@ -50,9 +50,12 @@ from . base import TrackerCode
 
 logger = logging.getLogger('RPCServer')
 
-def _server_env(load_library):
+def _server_env(load_library, work_path=None):
     """Server environment function return temp dir"""
-    temp = util.tempdir()
+    if work_path:
+        temp = work_path
+    else:
+        temp = util.tempdir()
 
     # pylint: disable=unused-variable
     @register_func("tvm.rpc.server.workpath")
@@ -76,16 +79,15 @@ def _server_env(load_library):
     temp.libs = libs
     return temp
 
-
-def _serve_loop(sock, addr, load_library):
+def _serve_loop(sock, addr, load_library, work_path=None):
     """Server loop"""
     sockfd = sock.fileno()
-    temp = _server_env(load_library)
+    temp = _server_env(load_library, work_path)
     base._ServerLoop(sockfd)
-    temp.remove()
+    if not work_path:
+        temp.remove()
     logger.info("Finish serving %s", addr)
 
-
 def _parse_server_opt(opts):
     # parse client options
     ret = {}
@@ -196,9 +198,10 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
             raise exc
 
         # step 3: serving
+        work_path = util.tempdir()
         logger.info("connection from %s", addr)
         server_proc = multiprocessing.Process(target=_serve_loop,
-                                              args=(conn, addr, load_library))
+                                              args=(conn, addr, load_library, work_path))
         server_proc.deamon = True
         server_proc.start()
         # close from our side.
@@ -208,6 +211,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
         if server_proc.is_alive():
             logger.info("Timeout in RPC session, kill..")
             server_proc.terminate()
+        work_path.remove()
 
 
 def _connect_proxy_loop(addr, key, load_library):