Fix for x86_64 build fail
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_presubmit / py / pw_presubmit / format_code.py
1 #!/usr/bin/env python3
2
3 # Copyright 2020 The Pigweed Authors
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
6 # use this file except in compliance with the License. You may obtain a copy of
7 # the License at
8 #
9 #     https://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14 # License for the specific language governing permissions and limitations under
15 # the License.
16 """Checks and fixes formatting for source files.
17
18 This uses clang-format, gn format, gofmt, and python -m yapf to format source
19 code. These tools must be available on the path when this script is invoked!
20 """
21
22 import argparse
23 import collections
24 import difflib
25 import logging
26 import os
27 from pathlib import Path
28 import re
29 import subprocess
30 import sys
31 from typing import Callable, Collection, Dict, Iterable, List, NamedTuple
32 from typing import Optional, Pattern, Tuple, Union
33
34 try:
35     import pw_presubmit
36 except ImportError:
37     # Append the pw_presubmit package path to the module search path to allow
38     # running this module without installing the pw_presubmit package.
39     sys.path.append(os.path.dirname(os.path.dirname(
40         os.path.abspath(__file__))))
41     import pw_presubmit
42
43 from pw_presubmit import cli, git_repo
44 from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural
45
46 _LOG: logging.Logger = logging.getLogger(__name__)
47
48
49 def _colorize_diff_line(line: str) -> str:
50     if line.startswith('--- ') or line.startswith('+++ '):
51         return pw_presubmit.color_bold_white(line)
52     if line.startswith('-'):
53         return pw_presubmit.color_red(line)
54     if line.startswith('+'):
55         return pw_presubmit.color_green(line)
56     if line.startswith('@@ '):
57         return pw_presubmit.color_aqua(line)
58     return line
59
60
61 def colorize_diff(lines: Iterable[str]) -> str:
62     """Takes a diff str or list of str lines and returns a colorized version."""
63     if isinstance(lines, str):
64         lines = lines.splitlines(True)
65
66     return ''.join(_colorize_diff_line(line) for line in lines)
67
68
69 def _diff(path, original: bytes, formatted: bytes) -> str:
70     return colorize_diff(
71         difflib.unified_diff(
72             original.decode(errors='replace').splitlines(True),
73             formatted.decode(errors='replace').splitlines(True),
74             f'{path}  (original)', f'{path}  (reformatted)'))
75
76
77 Formatter = Callable[[str, bytes], bytes]
78
79
80 def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
81     """Returns a diff comparing a file to its formatted version."""
82     with open(path, 'rb') as fd:
83         original = fd.read()
84
85     formatted = formatter(path, original)
86
87     return None if formatted == original else _diff(path, original, formatted)
88
89
90 def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
91     errors = {}
92
93     for path in files:
94         difference = _diff_formatted(path, formatter)
95         if difference:
96             errors[path] = difference
97
98     return errors
99
100
101 def _clang_format(*args: str, **kwargs) -> bytes:
102     return log_run(['clang-format', '--style=file', *args],
103                    stdout=subprocess.PIPE,
104                    check=True,
105                    **kwargs).stdout
106
107
108 def clang_format_check(files: Iterable[Path]) -> Dict[Path, str]:
109     """Checks formatting; returns {path: diff} for files with bad formatting."""
110     return _check_files(files, lambda path, _: _clang_format(path))
111
112
113 def clang_format_fix(files: Iterable) -> None:
114     """Fixes formatting for the provided files in place."""
115     _clang_format('-i', *files)
116
117
118 def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
119     """Checks formatting; returns {path: diff} for files with bad formatting."""
120     return _check_files(
121         files, lambda _, data: log_run(['gn', 'format', '--stdin'],
122                                        input=data,
123                                        stdout=subprocess.PIPE,
124                                        check=True).stdout)
125
126
127 def fix_gn_format(files: Iterable[Path]) -> None:
128     """Fixes formatting for the provided files in place."""
129     log_run(['gn', 'format', *files], check=True)
130
131
132 def check_go_format(files: Iterable[Path]) -> Dict[Path, str]:
133     """Checks formatting; returns {path: diff} for files with bad formatting."""
134     return _check_files(
135         files, lambda path, _: log_run(
136             ['gofmt', path], stdout=subprocess.PIPE, check=True).stdout)
137
138
139 def fix_go_format(files: Iterable[Path]) -> None:
140     """Fixes formatting for the provided files in place."""
141     log_run(['gofmt', '-w', *files], check=True)
142
143
144 def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
145     return log_run(['python', '-m', 'yapf', '--parallel', *args],
146                    capture_output=True,
147                    **kwargs)
148
149
150 _DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)
151
152
153 def check_py_format(files: Iterable[Path]) -> Dict[Path, str]:
154     """Checks formatting; returns {path: diff} for files with bad formatting."""
155     process = _yapf('--diff', *files)
156
157     errors: Dict[Path, str] = {}
158
159     if process.stdout:
160         raw_diff = process.stdout.decode(errors='replace')
161
162         matches = tuple(_DIFF_START.finditer(raw_diff))
163         for start, end in zip(matches, (*matches[1:], None)):
164             errors[Path(start.group(1))] = colorize_diff(
165                 raw_diff[start.start():end.start() if end else None])
166
167     if process.stderr:
168         _LOG.error('yapf encountered an error:\n%s',
169                    process.stderr.decode(errors='replace').rstrip())
170         errors.update({file: '' for file in files if file not in errors})
171
172     return errors
173
174
175 def fix_py_format(files: Iterable):
176     """Fixes formatting for the provided files in place."""
177     _yapf('--in-place', *files, check=True)
178
179
180 _TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)
181
182
183 def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
184     """Checks for and optionally removes trailing whitespace."""
185     errors = {}
186
187     for path in paths:
188         with path.open('rb') as fd:
189             contents = fd.read()
190
191         corrected = _TRAILING_SPACE.sub(b'', contents)
192         if corrected != contents:
193             errors[path] = _diff(path, contents, corrected)
194
195             if fix:
196                 with path.open('wb') as fd:
197                     fd.write(corrected)
198
199     return errors
200
201
202 def check_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
203     return _check_trailing_space(files, fix=False)
204
205
206 def fix_trailing_space(files: Iterable[Path]) -> None:
207     _check_trailing_space(files, fix=True)
208
209
210 def print_format_check(errors: Dict[Path, str],
211                        show_fix_commands: bool) -> None:
212     """Prints and returns the result of a check_*_format function."""
213     if not errors:
214         # Don't print anything in the all-good case.
215         return
216
217     # Show the format fixing diff suggested by the tooling (with colors).
218     _LOG.warning('Found %d files with formatting errors. Format changes:',
219                  len(errors))
220     for diff in errors.values():
221         print(diff, end='')
222
223     # Show a copy-and-pastable command to fix the issues.
224     if show_fix_commands:
225
226         def path_relative_to_cwd(path):
227             try:
228                 return Path(path).resolve().relative_to(Path.cwd().resolve())
229             except ValueError:
230                 return Path(path).resolve()
231
232         message = (f'  pw format --fix {path_relative_to_cwd(path)}'
233                    for path in errors)
234         _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))
235
236
237 class CodeFormat(NamedTuple):
238     language: str
239     extensions: Collection[str]
240     exclude: Collection[str]
241     check: Callable[[Iterable], Dict[Path, str]]
242     fix: Callable[[Iterable], None]
243
244
245 C_FORMAT: CodeFormat = CodeFormat(
246     'C and C++',
247     frozenset(['.h', '.hh', '.hpp', '.c', '.cc', '.cpp', '.inc', '.inl']),
248     (r'\.pb\.h$', r'\.pb\.c$'), clang_format_check, clang_format_fix)
249
250 PROTO_FORMAT: CodeFormat = CodeFormat('Protocol buffer', ('.proto', ), (),
251                                       clang_format_check, clang_format_fix)
252
253 JAVA_FORMAT: CodeFormat = CodeFormat('Java', ('.java', ), (),
254                                      clang_format_check, clang_format_fix)
255
256 JAVASCRIPT_FORMAT: CodeFormat = CodeFormat('JavaScript', ('.js', ), (),
257                                            clang_format_check,
258                                            clang_format_fix)
259
260 GO_FORMAT: CodeFormat = CodeFormat('Go', ('.go', ), (), check_go_format,
261                                    fix_go_format)
262
263 PYTHON_FORMAT: CodeFormat = CodeFormat('Python', ('.py', ), (),
264                                        check_py_format, fix_py_format)
265
266 GN_FORMAT: CodeFormat = CodeFormat('GN', ('.gn', '.gni'), (), check_gn_format,
267                                    fix_gn_format)
268
269 # TODO(pwbug/191): Add real code formatting support for Bazel and CMake
270 BAZEL_FORMAT: CodeFormat = CodeFormat('Bazel', ('BUILD', ), (),
271                                       check_trailing_space, fix_trailing_space)
272
273 CMAKE_FORMAT: CodeFormat = CodeFormat('CMake', ('CMakeLists.txt', '.cmake'),
274                                       (), check_trailing_space,
275                                       fix_trailing_space)
276
277 RST_FORMAT: CodeFormat = CodeFormat('reStructuredText', ('.rst', ), (),
278                                     check_trailing_space, fix_trailing_space)
279
280 MARKDOWN_FORMAT: CodeFormat = CodeFormat('Markdown', ('.md', ), (),
281                                          check_trailing_space,
282                                          fix_trailing_space)
283
284 CODE_FORMATS: Tuple[CodeFormat, ...] = (
285     C_FORMAT,
286     JAVA_FORMAT,
287     JAVASCRIPT_FORMAT,
288     PROTO_FORMAT,
289     GO_FORMAT,
290     PYTHON_FORMAT,
291     GN_FORMAT,
292     BAZEL_FORMAT,
293     CMAKE_FORMAT,
294     RST_FORMAT,
295     MARKDOWN_FORMAT,
296 )
297
298
299 def presubmit_check(code_format: CodeFormat, **filter_paths_args) -> Callable:
300     """Creates a presubmit check function from a CodeFormat object."""
301     filter_paths_args.setdefault('endswith', code_format.extensions)
302     filter_paths_args.setdefault('exclude', code_format.exclude)
303
304     @pw_presubmit.filter_paths(**filter_paths_args)
305     def check_code_format(ctx: pw_presubmit.PresubmitContext):
306         errors = code_format.check(ctx.paths)
307         print_format_check(
308             errors,
309             # When running as part of presubmit, show the fix command help.
310             show_fix_commands=True,
311         )
312         if errors:
313             raise pw_presubmit.PresubmitFailure
314
315     language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
316     check_code_format.__name__ = f'{language}_format'
317
318     return check_code_format
319
320
321 def presubmit_checks(**filter_paths_args) -> Tuple[Callable, ...]:
322     """Returns a tuple with all supported code format presubmit checks."""
323     return tuple(
324         presubmit_check(fmt, **filter_paths_args) for fmt in CODE_FORMATS)
325
326
327 class CodeFormatter:
328     """Checks or fixes the formatting of a set of files."""
329     def __init__(self, files: Iterable[Path]):
330         self.paths = list(files)
331         self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)
332
333         for path in self.paths:
334             for code_format in CODE_FORMATS:
335                 if any(path.as_posix().endswith(e)
336                        for e in code_format.extensions):
337                     self._formats[code_format].append(path)
338
339     def check(self) -> Dict[Path, str]:
340         """Returns {path: diff} for files with incorrect formatting."""
341         errors: Dict[Path, str] = {}
342
343         for code_format, files in self._formats.items():
344             _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
345             errors.update(code_format.check(files))
346
347         return collections.OrderedDict(sorted(errors.items()))
348
349     def fix(self) -> None:
350         """Fixes format errors for supported files in place."""
351         for code_format, files in self._formats.items():
352             code_format.fix(files)
353             _LOG.info('Formatted %s',
354                       plural(files, code_format.language + ' file'))
355
356
357 def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
358     try:
359         return file_summary(
360             Path(f).resolve().relative_to(base.resolve()) for f in files)
361     except ValueError:
362         return []
363
364
365 def format_paths_in_repo(paths: Collection[Union[Path, str]],
366                          exclude: Collection[Pattern[str]], fix: bool,
367                          base: str) -> int:
368     """Checks or fixes formatting for files in a Git repo."""
369     files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
370     repo = git_repo.root() if git_repo.is_repo() else None
371
372     # If this is a Git repo, list the original paths with git ls-files or diff.
373     if repo:
374         _LOG.info(
375             'Formatting %s',
376             git_repo.describe_files(repo, Path.cwd(), base, paths, exclude))
377
378         # Add files from Git and remove duplicates.
379         files = sorted(
380             set(exclude_paths(exclude, git_repo.list_files(base, paths)))
381             | set(files))
382     elif base:
383         _LOG.critical(
384             'A base commit may only be provided if running from a Git repo')
385         return 1
386
387     return format_files(files, fix, repo=repo)
388
389
390 def format_files(paths: Collection[Union[Path, str]],
391                  fix: bool,
392                  repo: Optional[Path] = None) -> int:
393     """Checks or fixes formatting for the specified files."""
394     formatter = CodeFormatter(Path(p) for p in paths)
395
396     _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))
397
398     for line in _file_summary(paths, repo if repo else Path.cwd()):
399         print(line, file=sys.stderr)
400
401     errors = formatter.check()
402     print_format_check(errors, show_fix_commands=(not fix))
403
404     if errors:
405         if fix:
406             formatter.fix()
407             # TODO: This should perhaps check that the fixes were successful.
408             _LOG.info('Formatting fixes applied successfully')
409             return 0
410
411         _LOG.error('Formatting errors found')
412         return 1
413
414     _LOG.info('Congratulations! No formatting changes needed')
415     return 0
416
417
418 def arguments(git_paths: bool) -> argparse.ArgumentParser:
419     """Creates an argument parser for format_files or format_paths_in_repo."""
420
421     parser = argparse.ArgumentParser(description=__doc__)
422
423     if git_paths:
424         cli.add_path_arguments(parser)
425     else:
426
427         def existing_path(arg: str) -> Path:
428             path = Path(arg)
429             if not path.is_file():
430                 raise argparse.ArgumentTypeError(
431                     f'{arg} is not a path to a file')
432
433             return path
434
435         parser.add_argument('paths',
436                             metavar='path',
437                             nargs='+',
438                             type=existing_path,
439                             help='File paths to check')
440
441     parser.add_argument('--fix',
442                         action='store_true',
443                         help='Apply formatting fixes in place.')
444     return parser
445
446
447 def main() -> int:
448     """Check and fix formatting for source files."""
449     return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args()))
450
451
452 if __name__ == '__main__':
453     try:
454         # If pw_cli is available, use it to initialize logs.
455         from pw_cli import log
456
457         log.install(logging.INFO)
458     except ImportError:
459         # If pw_cli isn't available, display log messages like a simple print.
460         logging.basicConfig(format='%(message)s', level=logging.INFO)
461
462     sys.exit(main())