Add debug mode to tempdir() (#5581)
authorAndrew Reusch <areusch@octoml.ai>
Fri, 15 May 2020 21:39:45 +0000 (14:39 -0700)
committerGitHub <noreply@github.com>
Fri, 15 May 2020 21:39:45 +0000 (14:39 -0700)
python/tvm/contrib/util.py
tests/python/contrib/test_util.py [new file with mode: 0644]

index e980e55..8f6dfc7 100644 (file)
 # under the License.
 """Common system utilities"""
 import atexit
+import contextlib
+import datetime
 import os
 import tempfile
+import threading
 import shutil
 try:
     import fcntl
 except ImportError:
     fcntl = None
 
+
+class DirectoryCreatedPastAtExit(Exception):
+    """Raised when a TempDirectory is created after the atexit hook runs."""
+
 class TempDirectory(object):
     """Helper object to manage temp directory during testing.
 
     Automatically removes the directory when it went out of scope.
     """
 
+    # When True, all TempDirectory are *NOT* deleted and instead live inside a predicable directory
+    # tree.
+    _KEEP_FOR_DEBUG = False
+
+    # In debug mode, each tempdir is named after the sequence
+    _NUM_TEMPDIR_CREATED = 0
+    _NUM_TEMPDIR_CREATED_LOCK = threading.Lock()
+    @classmethod
+    def _increment_num_tempdir_created(cls):
+        with cls._NUM_TEMPDIR_CREATED_LOCK:
+            to_return = cls._NUM_TEMPDIR_CREATED
+            cls._NUM_TEMPDIR_CREATED += 1
+
+        return to_return
+
+    _DEBUG_PARENT_DIR = None
+    @classmethod
+    def _get_debug_parent_dir(cls):
+        if cls._DEBUG_PARENT_DIR is None:
+            all_parents = f'{tempfile.gettempdir()}/tvm-debug-mode-tempdirs'
+            if not os.path.isdir(all_parents):
+                os.makedirs(all_parents)
+            cls._DEBUG_PARENT_DIR = tempfile.mkdtemp(
+                prefix=datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S___'), dir=all_parents)
+        return cls._DEBUG_PARENT_DIR
+
     TEMPDIRS = set()
     @classmethod
     def remove_tempdirs(cls):
@@ -42,20 +75,42 @@ class TempDirectory(object):
 
         cls.TEMPDIRS = None
 
+    @classmethod
+    @contextlib.contextmanager
+    def set_keep_for_debug(cls, set_to=True):
+        """Keep temporary directories past program exit for debugging."""
+        old_keep_for_debug = cls._KEEP_FOR_DEBUG
+        try:
+            cls._KEEP_FOR_DEBUG = set_to
+            yield
+        finally:
+            cls._KEEP_FOR_DEBUG = old_keep_for_debug
+
     def __init__(self, custom_path=None):
+        if self.TEMPDIRS is None:
+            raise DirectoryCreatedPastAtExit()
+
+        self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG
         if custom_path:
             os.mkdir(custom_path)
             self.temp_dir = custom_path
         else:
-            self.temp_dir = tempfile.mkdtemp()
+            if self._created_with_keep_for_debug:
+                parent_dir = self._get_debug_parent_dir()
+                self.temp_dir = f'{parent_dir}/{self._increment_num_tempdir_created():05d}'
+                os.mkdir(self.temp_dir)
+            else:
+                self.temp_dir = tempfile.mkdtemp()
 
-        self.TEMPDIRS.add(self.temp_dir)
+        if not self._created_with_keep_for_debug:
+            self.TEMPDIRS.add(self.temp_dir)
 
     def remove(self):
         """Remote the tmp dir"""
         if self.temp_dir:
-            shutil.rmtree(self.temp_dir, ignore_errors=True)
-            self.TEMPDIRS.remove(self.temp_dir)
+            if not self._created_with_keep_for_debug:
+                shutil.rmtree(self.temp_dir, ignore_errors=True)
+                self.TEMPDIRS.remove(self.temp_dir)
             self.temp_dir = None
 
     def __del__(self):
diff --git a/tests/python/contrib/test_util.py b/tests/python/contrib/test_util.py
new file mode 100644 (file)
index 0000000..55a2b76
--- /dev/null
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for functions in tvm/python/tvm/contrib/util.py."""
+
+import datetime
+import os
+import shutil
+from tvm.contrib import util
+
+
+def validate_debug_dir_path(temp_dir, expected_basename):
+  dirname, basename = os.path.split(temp_dir.temp_dir)
+  assert basename == expected_basename, 'unexpected basename: %s' % (basename,)
+
+  parent_dir = os.path.basename(dirname)
+  create_time = datetime.datetime.strptime(parent_dir.split('___', 1)[0], '%Y-%m-%dT%H-%M-%S')
+  assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60)
+
+
+
+def test_tempdir():
+  assert util.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True"
+
+  temp_dir = util.tempdir()
+  assert os.path.exists(temp_dir.temp_dir)
+
+  old_debug_mode = util.TempDirectory._KEEP_FOR_DEBUG
+  try:
+    for temp_dir_number in range(0, 3):
+      with util.TempDirectory.set_keep_for_debug():
+        debug_temp_dir = util.tempdir()
+        try:
+          validate_debug_dir_path(debug_temp_dir, '0000' + str(temp_dir_number))
+        finally:
+          shutil.rmtree(debug_temp_dir.temp_dir)
+
+    with util.TempDirectory.set_keep_for_debug():
+      # Create 2 temp_dir within the same session.
+      debug_temp_dir = util.tempdir()
+      try:
+        validate_debug_dir_path(debug_temp_dir, '00003')
+      finally:
+        shutil.rmtree(debug_temp_dir.temp_dir)
+
+      debug_temp_dir = util.tempdir()
+      try:
+        validate_debug_dir_path(debug_temp_dir, '00004')
+      finally:
+        shutil.rmtree(debug_temp_dir.temp_dir)
+
+      with util.TempDirectory.set_keep_for_debug(False):
+        debug_temp_dir = util.tempdir()  # This one should get deleted.
+
+        # Simulate atexit hook
+        util.TempDirectory.remove_tempdirs()
+
+        # Calling twice should be a no-op.
+        util.TempDirectory.remove_tempdirs()
+
+        # Creating a new TempDirectory should fail now
+        try:
+          util.tempdir()
+          assert False, 'creation should fail'
+        except util.DirectoryCreatedPastAtExit:
+          pass
+
+  finally:
+    util.TempDirectory.DEBUG_MODE = old_debug_mode
+
+
+if __name__ == '__main__':
+  test_tempdir()