2 # Copyright 2020 The Pigweed Authors
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
5 # use this file except in compliance with the License. You may obtain a copy of
8 # https://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
15 """File Helper Functions."""
import glob
import hashlib
import json
import logging
import os
import shutil
import subprocess
import tarfile
import urllib.request
import zipfile
from pathlib import Path
from typing import List

_LOG = logging.getLogger(__name__)
34 class InvalidChecksumError(Exception):
def find_files(starting_dir: str,
               patterns: List[str],
               directories_only=False) -> List[str]:
    """Find files or directories matching glob patterns.

    Args:
        starting_dir: Directory to search within. Matches are returned
            relative to this directory.
        patterns: Glob patterns (``**`` is supported) to match.
        directories_only: If True, only return matches that are directories.

    Returns:
        Sorted list of matching paths, relative to starting_dir.

    Raises:
        FileNotFoundError: If starting_dir is not an existing directory.
    """
    if not (os.path.exists(starting_dir) and os.path.isdir(starting_dir)):
        raise FileNotFoundError(
            "Directory '{}' does not exist.".format(starting_dir))

    original_working_dir = os.getcwd()
    files = []
    # chdir so glob results come back relative to starting_dir; restore
    # the working directory even if globbing raises.
    os.chdir(starting_dir)
    try:
        for pattern in patterns:
            for file_path in glob.glob(pattern, recursive=True):
                if not directories_only or os.path.isdir(file_path):
                    files.append(file_path)
    finally:
        os.chdir(original_working_dir)
    return sorted(files)
def sha256_sum(file_name):
    """Return the hex SHA-256 digest of the file at file_name."""
    digest = hashlib.sha256()
    with open(file_name, "rb") as stream:
        # Read in 4 KiB chunks so large files don't load into memory at once.
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def md5_sum(file_name):
    """Return the hex MD5 digest of the file at file_name."""
    checksum = hashlib.md5()
    # Stream the file in fixed-size blocks rather than reading it whole.
    with open(file_name, "rb") as stream:
        for block in iter(lambda: stream.read(4096), b""):
            checksum.update(block)
    return checksum.hexdigest()
def verify_file_checksum(file_path,
                         expected_checksum,
                         sum_function=sha256_sum):
    """Verify a file's checksum, raising on mismatch.

    Args:
        file_path: Path of the file to verify.
        expected_checksum: Expected hex digest string.
        sum_function: Callable that returns a file's hex digest
            (e.g. sha256_sum or md5_sum).

    Raises:
        InvalidChecksumError: If the computed digest does not match
            expected_checksum.
    """
    downloaded_checksum = sum_function(file_path)
    if downloaded_checksum != expected_checksum:
        raise InvalidChecksumError(
            f"Invalid {sum_function.__name__}\n"
            f"{downloaded_checksum} {os.path.basename(file_path)}\n"
            f"{expected_checksum} (expected)\n\n"
            "Please delete this file and try again:\n"
            f"{file_path}")

    _LOG.debug("  %s:", sum_function.__name__)
    _LOG.debug("  %s %s", downloaded_checksum, os.path.basename(file_path))
def relative_or_absolute_path(file_string: str):
    """Return a Path relative to os.getcwd(), else an absolute path.

    Path.relative_to raises ValueError when the path is not beneath the
    current working directory; in that case fall back to the resolved
    absolute path.
    """
    file_path = Path(file_string)
    try:
        return file_path.relative_to(os.getcwd())
    except ValueError:
        return file_path.resolve()
def download_to_cache(url: str,
                      expected_md5sum=None,
                      expected_sha256sum=None,
                      cache_directory=".cache",
                      downloaded_file_name=None) -> str:
    """Download a file into a cache directory, verifying its checksum.

    Skips the download if the file already exists in the cache.

    Args:
        url: URL to fetch.
        expected_md5sum: Optional expected MD5 hex digest.
        expected_sha256sum: Optional expected SHA-256 hex digest
            (takes precedence over expected_md5sum).
        cache_directory: Directory to store the download in; '~' and
            environment variables are expanded.
        downloaded_file_name: Override for the destination file name;
            defaults to the last component of the URL.

    Returns:
        Absolute path to the downloaded (or cached) file.

    Raises:
        InvalidChecksumError: If a provided checksum does not match.
    """
    cache_dir = os.path.realpath(
        os.path.expanduser(os.path.expandvars(cache_directory)))
    # urlretrieve does not create intermediate directories.
    os.makedirs(cache_dir, exist_ok=True)

    if not downloaded_file_name:
        # Use the last part of the URL as the file name.
        downloaded_file_name = url.split("/")[-1]
    downloaded_file = os.path.join(cache_dir, downloaded_file_name)

    if not os.path.exists(downloaded_file):
        _LOG.info("Downloading: %s", url)
        _LOG.info("Please wait...")
        urllib.request.urlretrieve(url, filename=downloaded_file)

    if os.path.exists(downloaded_file):
        _LOG.info("Downloaded: %s", relative_or_absolute_path(downloaded_file))
        if expected_sha256sum:
            verify_file_checksum(downloaded_file,
                                 expected_sha256sum,
                                 sum_function=sha256_sum)
        elif expected_md5sum:
            verify_file_checksum(downloaded_file,
                                 expected_md5sum,
                                 sum_function=md5_sum)

    return downloaded_file
def extract_zipfile(archive_file: str, dest_dir: str):
    """Extract a zipfile preserving permissions."""
    destination_path = Path(dest_dir)
    with zipfile.ZipFile(archive_file) as archive:
        for info in archive.infolist():
            archive.extract(info.filename, path=dest_dir)
            # The upper 16 bits of external_attr hold the Unix mode bits.
            permissions = info.external_attr >> 16
            # Entries written without Unix attributes report 0; chmod(0)
            # would make the extracted file inaccessible, so skip those.
            if permissions:
                out_path = destination_path / info.filename
                out_path.chmod(permissions)
def extract_tarfile(archive_file: str, dest_dir: str):
    """Extract a tar archive (any supported compression) into dest_dir.

    Uses the 'data' extraction filter when available (Python 3.12+,
    backported to some 3.11/3.10 releases) to reject absolute paths and
    path-traversal members in untrusted archives (CVE-2007-4559).
    """
    with tarfile.open(archive_file, 'r') as archive:
        try:
            archive.extractall(path=dest_dir, filter='data')
        except TypeError:
            # Older Python without filter support: extract as before.
            archive.extractall(path=dest_dir)
def extract_archive(archive_file: str,
                    dest_dir: str,
                    cache_dir: str,
                    remove_single_toplevel_folder=True):
    """Extract a tar or zip file.

    Args:
        archive_file (str): Absolute path to the archive file.
        dest_dir (str): Extraction destination directory.
        cache_dir (str): Directory where temp files can be created.
        remove_single_toplevel_folder (bool): If the archive contains only a
            single folder move the contents of that into the destination
            directory.

    Returns:
        List of extracted Paths under dest_dir (empty on unknown format).
    """
    # Make a temporary directory to extract files into
    temp_extract_dir = os.path.join(cache_dir,
                                    "." + os.path.basename(archive_file))
    os.makedirs(temp_extract_dir, exist_ok=True)

    _LOG.info("Extracting: %s", relative_or_absolute_path(archive_file))
    if zipfile.is_zipfile(archive_file):
        extract_zipfile(archive_file, temp_extract_dir)
    elif tarfile.is_tarfile(archive_file):
        extract_tarfile(archive_file, temp_extract_dir)
    else:
        _LOG.error("Unknown archive format: %s", archive_file)
        return []

    _LOG.info("Installing into: %s", relative_or_absolute_path(dest_dir))
    path_to_extracted_files = temp_extract_dir

    extracted_top_level_files = os.listdir(temp_extract_dir)
    # If the archive holds exactly one top-level *directory*, treat its
    # contents as the payload (a single top-level file must not be
    # descended into).
    if (remove_single_toplevel_folder
            and len(extracted_top_level_files) == 1
            and os.path.isdir(
                os.path.join(temp_extract_dir,
                             extracted_top_level_files[0]))):
        path_to_extracted_files = os.path.join(temp_extract_dir,
                                               extracted_top_level_files[0])

    # Move extracted files to dest_dir
    os.makedirs(dest_dir, exist_ok=True)
    for file_name in os.listdir(path_to_extracted_files):
        source_file = os.path.join(path_to_extracted_files, file_name)
        dest_file = os.path.join(dest_dir, file_name)
        shutil.move(source_file, dest_file)

    # rm -rf temp_extract_dir
    shutil.rmtree(temp_extract_dir, ignore_errors=True)

    # Return List of extracted files
    return list(Path(dest_dir).rglob("*"))
def remove_empty_directories(directory):
    """Recursively remove empty directories (and broken symlinks).

    Walks deepest paths first so that a directory emptied by removing its
    children is itself removed on the same pass.
    """
    for path in sorted(Path(directory).rglob("*"), reverse=True):
        # Broken symlink: is_symlink() is true but the target is gone.
        if path.is_symlink() and not path.exists():
            path.unlink()
        # Empty directory.
        elif path.is_dir() and len(os.listdir(path)) == 0:
            path.rmdir()
def decode_file_json(file_name):
    """Decode JSON values from a file.

    Does not raise an error if the file cannot be decoded.

    Args:
        file_name: Path to the JSON file; '~' and environment variables
            are expanded.

    Returns:
        Tuple of (decoded options dict — empty on failure, absolute file
        path).
    """
    # Get absolute path to the file.
    file_path = os.path.realpath(
        os.path.expanduser(os.path.expandvars(file_name)))

    json_file_options = {}
    try:
        with open(file_path, "r") as jfile:
            json_file_options = json.loads(jfile.read())
    except (FileNotFoundError, json.JSONDecodeError):
        # Best effort: warn and return the empty defaults.
        _LOG.warning("Unable to read file '%s'", file_path)

    return json_file_options, file_path
def git_apply_patch(root_directory,
                    patch_file,
                    ignore_whitespace=True,
                    unsafe_paths=False):
    """Use `git apply` to apply a diff file.

    Args:
        root_directory: Directory passed to `git apply --directory`.
        patch_file: Path to the diff/patch file to apply.
        ignore_whitespace: Pass --ignore-whitespace to git apply.
        unsafe_paths: Pass --unsafe-paths to git apply (allows patching
            files outside the repository).

    Returns:
        The subprocess.CompletedProcess from running git apply.
    """
    _LOG.info("Applying Patch: %s", patch_file)
    git_apply_command = ["git", "apply"]
    if ignore_whitespace:
        git_apply_command.append("--ignore-whitespace")
    if unsafe_paths:
        git_apply_command.append("--unsafe-paths")
    git_apply_command += ["--directory", root_directory, patch_file]
    # Argument-list form (shell=False) avoids shell injection; return the
    # result so callers can inspect the exit code.
    return subprocess.run(git_apply_command)