# under the License.
"""Common system utilities"""
import atexit
+import contextlib
+import datetime
import os
import tempfile
+import threading
import shutil
try:
import fcntl
except ImportError:
fcntl = None
+
+class DirectoryCreatedPastAtExit(Exception):
+ """Raised when a TempDirectory is created after the atexit hook runs."""
+
class TempDirectory(object):
"""Helper object to manage temp directory during testing.
Automatically removes the directory when it went out of scope.
"""
+ # When True, all TempDirectory are *NOT* deleted and instead live inside a predicable directory
+ # tree.
+ _KEEP_FOR_DEBUG = False
+
+ # In debug mode, each tempdir is named after the sequence
+ _NUM_TEMPDIR_CREATED = 0
+ _NUM_TEMPDIR_CREATED_LOCK = threading.Lock()
+ @classmethod
+ def _increment_num_tempdir_created(cls):
+ with cls._NUM_TEMPDIR_CREATED_LOCK:
+ to_return = cls._NUM_TEMPDIR_CREATED
+ cls._NUM_TEMPDIR_CREATED += 1
+
+ return to_return
+
+ _DEBUG_PARENT_DIR = None
+ @classmethod
+ def _get_debug_parent_dir(cls):
+ if cls._DEBUG_PARENT_DIR is None:
+ all_parents = f'{tempfile.gettempdir()}/tvm-debug-mode-tempdirs'
+ if not os.path.isdir(all_parents):
+ os.makedirs(all_parents)
+ cls._DEBUG_PARENT_DIR = tempfile.mkdtemp(
+ prefix=datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S___'), dir=all_parents)
+ return cls._DEBUG_PARENT_DIR
+
TEMPDIRS = set()
@classmethod
def remove_tempdirs(cls):
cls.TEMPDIRS = None
+ @classmethod
+ @contextlib.contextmanager
+ def set_keep_for_debug(cls, set_to=True):
+ """Keep temporary directories past program exit for debugging."""
+ old_keep_for_debug = cls._KEEP_FOR_DEBUG
+ try:
+ cls._KEEP_FOR_DEBUG = set_to
+ yield
+ finally:
+ cls._KEEP_FOR_DEBUG = old_keep_for_debug
+
def __init__(self, custom_path=None):
+ if self.TEMPDIRS is None:
+ raise DirectoryCreatedPastAtExit()
+
+ self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG
if custom_path:
os.mkdir(custom_path)
self.temp_dir = custom_path
else:
- self.temp_dir = tempfile.mkdtemp()
+ if self._created_with_keep_for_debug:
+ parent_dir = self._get_debug_parent_dir()
+ self.temp_dir = f'{parent_dir}/{self._increment_num_tempdir_created():05d}'
+ os.mkdir(self.temp_dir)
+ else:
+ self.temp_dir = tempfile.mkdtemp()
- self.TEMPDIRS.add(self.temp_dir)
+ if not self._created_with_keep_for_debug:
+ self.TEMPDIRS.add(self.temp_dir)
def remove(self):
"""Remote the tmp dir"""
if self.temp_dir:
- shutil.rmtree(self.temp_dir, ignore_errors=True)
- self.TEMPDIRS.remove(self.temp_dir)
+ if not self._created_with_keep_for_debug:
+ shutil.rmtree(self.temp_dir, ignore_errors=True)
+ self.TEMPDIRS.remove(self.temp_dir)
self.temp_dir = None
def __del__(self):
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for functions in tvm/python/tvm/contrib/util.py."""
+
+import datetime
+import os
+import shutil
+from tvm.contrib import util
+
+
+def validate_debug_dir_path(temp_dir, expected_basename):
+ dirname, basename = os.path.split(temp_dir.temp_dir)
+ assert basename == expected_basename, 'unexpected basename: %s' % (basename,)
+
+ parent_dir = os.path.basename(dirname)
+ create_time = datetime.datetime.strptime(parent_dir.split('___', 1)[0], '%Y-%m-%dT%H-%M-%S')
+ assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60)
+
+
+
+def test_tempdir():
+ assert util.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True"
+
+ temp_dir = util.tempdir()
+ assert os.path.exists(temp_dir.temp_dir)
+
+ old_debug_mode = util.TempDirectory._KEEP_FOR_DEBUG
+ try:
+ for temp_dir_number in range(0, 3):
+ with util.TempDirectory.set_keep_for_debug():
+ debug_temp_dir = util.tempdir()
+ try:
+ validate_debug_dir_path(debug_temp_dir, '0000' + str(temp_dir_number))
+ finally:
+ shutil.rmtree(debug_temp_dir.temp_dir)
+
+ with util.TempDirectory.set_keep_for_debug():
+ # Create 2 temp_dir within the same session.
+ debug_temp_dir = util.tempdir()
+ try:
+ validate_debug_dir_path(debug_temp_dir, '00003')
+ finally:
+ shutil.rmtree(debug_temp_dir.temp_dir)
+
+ debug_temp_dir = util.tempdir()
+ try:
+ validate_debug_dir_path(debug_temp_dir, '00004')
+ finally:
+ shutil.rmtree(debug_temp_dir.temp_dir)
+
+ with util.TempDirectory.set_keep_for_debug(False):
+ debug_temp_dir = util.tempdir() # This one should get deleted.
+
+ # Simulate atexit hook
+ util.TempDirectory.remove_tempdirs()
+
+ # Calling twice should be a no-op.
+ util.TempDirectory.remove_tempdirs()
+
+ # Creating a new TempDirectory should fail now
+ try:
+ util.tempdir()
+ assert False, 'creation should fail'
+ except util.DirectoryCreatedPastAtExit:
+ pass
+
+ finally:
+ util.TempDirectory.DEBUG_MODE = old_debug_mode
+
+
+if __name__ == '__main__':
+ test_tempdir()