3 # Copyright 2020 The Pigweed Authors
5 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
6 # use this file except in compliance with the License. You may obtain a copy of
9 # https://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14 # License for the specific language governing permissions and limitations under
16 """Checks and fixes formatting for source files.
18 This uses clang-format, gn format, gofmt, and python -m yapf to format source
19 code. These tools must be available on the path when this script is invoked!
27 from pathlib import Path
31 from typing import Callable, Collection, Dict, Iterable, List, NamedTuple
32 from typing import Optional, Pattern, Tuple, Union
37 # Append the pw_presubmit package path to the module search path to allow
38 # running this module without installing the pw_presubmit package.
39 sys.path.append(os.path.dirname(os.path.dirname(
40 os.path.abspath(__file__))))
43 from pw_presubmit import cli, git_repo
44 from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural
# Module-level logger, named after this module via __name__.
46 _LOG: logging.Logger = logging.getLogger(__name__)
# Picks a pw_presubmit color helper for one diff line based on its prefix.
# The '--- ' / '+++ ' file-header test must come before the bare '-' / '+'
# tests, because header lines also start with those characters.
# NOTE(review): the fall-through for lines matching none of these prefixes is
# not visible here; presumably the line is returned unchanged — confirm.
49 def _colorize_diff_line(line: str) -> str:
50 if line.startswith('--- ') or line.startswith('+++ '):
51 return pw_presubmit.color_bold_white(line)
52 if line.startswith('-'):
53 return pw_presubmit.color_red(line)
54 if line.startswith('+'):
55 return pw_presubmit.color_green(line)
56 if line.startswith('@@ '):
57 return pw_presubmit.color_aqua(line)
def colorize_diff(lines: Iterable[str]) -> str:
    """Colorizes a diff and returns it as a single string.

    Args:
      lines: The diff, either as one str (which is split into lines with the
          line endings kept) or as an iterable of individual diff lines.

    Returns:
      The concatenation of every line after _colorize_diff_line is applied.
    """
    diff_lines = lines.splitlines(True) if isinstance(lines, str) else lines
    return ''.join(map(_colorize_diff_line, diff_lines))
# Builds a diff between the original and formatted contents of `path`.
# Both byte strings are decoded with errors='replace' (undecodable bytes do not
# raise) and split into lines with their line endings kept. The two sides are
# labeled '<path> (original)' and '<path> (reformatted)'.
69 def _diff(path, original: bytes, formatted: bytes) -> str:
72 original.decode(errors='replace').splitlines(True),
73 formatted.decode(errors='replace').splitlines(True),
74 f'{path} (original)', f'{path} (reformatted)'))
# Signature of a formatting function: called as formatter(path, contents) with
# the file's bytes, and returns the formatted bytes.
77 Formatter = Callable[[str, bytes], bytes]
80 def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
81 """Returns a diff comparing a file to its formatted version."""
# The file is opened in binary mode, so the comparison below is on raw bytes.
82 with open(path, 'rb') as fd:
85 formatted = formatter(path, original)
# None signals that the formatter's output equals the original contents.
87 return None if formatted == original else _diff(path, original, formatted)
# Applies `formatter` to files via _diff_formatted and records results in a
# {path: diff} mapping (per the return annotation).
# NOTE(review): the loop header and the condition guarding the assignment are
# not visible here; presumably only non-None diffs are stored — confirm.
90 def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
94 difference = _diff_formatted(path, formatter)
96 errors[path] = difference
# Runs clang-format through log_run with --style=file followed by the caller's
# positional arguments, capturing stdout in a pipe.
101 def _clang_format(*args: str, **kwargs) -> bytes:
102 return log_run(['clang-format', '--style=file', *args],
103 stdout=subprocess.PIPE,
def clang_format_check(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks files with clang-format; returns {path: diff} for bad files."""
    def format_with_clang(path, _contents):
        # Only the path is handed to clang-format; the contents that were
        # already read by the caller are not used.
        return _clang_format(path)

    return _check_files(files, format_with_clang)
def clang_format_fix(files: Iterable) -> None:
    """Rewrites the provided files in place using clang-format's -i flag."""
    in_place_args = ('-i', *files)
    _clang_format(*in_place_args)
118 def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
119 """Checks formatting; returns {path: diff} for files with bad formatting."""
# The formatter ignores the path argument and runs `gn format --stdin` through
# log_run, capturing stdout in a pipe.
121 files, lambda _, data: log_run(['gn', 'format', '--stdin'],
123 stdout=subprocess.PIPE,
def fix_gn_format(files: Iterable[Path]) -> None:
    """Runs `gn format` on the provided files, rewriting them in place."""
    command = ['gn', 'format']
    command.extend(files)
    log_run(command, check=True)
132 def check_go_format(files: Iterable[Path]) -> Dict[Path, str]:
133 """Checks formatting; returns {path: diff} for files with bad formatting."""
# The formatter ignores the file contents it is given and runs `gofmt <path>`;
# the captured stdout of that run is used as the formatted result.
135 files, lambda path, _: log_run(
136 ['gofmt', path], stdout=subprocess.PIPE, check=True).stdout)
def fix_go_format(files: Iterable[Path]) -> None:
    """Runs `gofmt -w` on the provided files, rewriting them in place."""
    command = ['gofmt', '-w']
    command.extend(files)
    log_run(command, check=True)
# Runs `python -m yapf --parallel` through log_run, appending the caller's
# positional arguments to the command line.
144 def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
145 return log_run(['python', '-m', 'yapf', '--parallel', *args],
# Matches the '--- <path> (original)' header line that begins each file's
# section of a diff; group 1 captures the path. MULTILINE lets ^ and $ anchor
# at every line of a multi-file diff.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', re.MULTILINE)
153 def check_py_format(files: Iterable[Path]) -> Dict[Path, str]:
154 """Checks formatting; returns {path: diff} for files with bad formatting."""
# All files are checked in a single yapf --diff invocation.
155 process = _yapf('--diff', *files)
157 errors: Dict[Path, str] = {}
160 raw_diff = process.stdout.decode(errors='replace')
# Split the combined output into per-file diffs: each _DIFF_START match opens
# a section that runs up to the next match, or to the end of the output for
# the last one (the trailing None in the zip). The section is keyed by the
# path captured from its header.
162 matches = tuple(_DIFF_START.finditer(raw_diff))
163 for start, end in zip(matches, (*matches[1:], None)):
164 errors[Path(start.group(1))] = colorize_diff(
165 raw_diff[start.start():end.start() if end else None])
# After logging yapf's stderr, every file that has no diff yet is recorded
# with an empty diff string.
# NOTE(review): `files` was already unpacked in the _yapf call above; if a
# one-shot iterator is passed, this second pass would see nothing — confirm
# callers always pass a re-iterable collection.
168 _LOG.error('yapf encountered an error:\n%s',
169 process.stderr.decode(errors='replace').rstrip())
170 errors.update({file: '' for file in files if file not in errors})
def fix_py_format(files: Iterable) -> None:
    """Fixes formatting for the provided files in place.

    Runs yapf with --in-place over all of the files in a single invocation,
    forwarding check=True to _yapf.

    Args:
      files: Paths of the Python files to reformat.
    """
    _yapf('--in-place', *files, check=True)
# Matches a run of spaces and/or tabs at the end of any line of a bytes
# buffer; MULTILINE makes $ anchor at every line end, not only at the end of
# the data.
_TRAILING_SPACE = re.compile(rb'[ \t]+$', re.MULTILINE)
183 def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
184 """Checks for and optionally removes trailing whitespace."""
# Files are read as bytes, matching the bytes pattern in _TRAILING_SPACE.
188 with path.open('rb') as fd:
# Delete every run of trailing spaces/tabs; a file counts as an error only if
# that changes its contents.
191 corrected = _TRAILING_SPACE.sub(b'', contents)
192 if corrected != contents:
193 errors[path] = _diff(path, contents, corrected)
# NOTE(review): the condition guarding this write is not visible here;
# presumably it only happens when `fix` is true — confirm.
196 with path.open('wb') as fd:
def check_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks files for trailing whitespace with fixing disabled.

    Returns:
      The {path: diff} mapping produced by _check_trailing_space.
    """
    apply_fix = False
    return _check_trailing_space(files, fix=apply_fix)
def fix_trailing_space(files: Iterable[Path]) -> None:
    """Runs the trailing whitespace check with fixing enabled.

    The {path: diff} mapping returned by _check_trailing_space is discarded.
    """
    _ = _check_trailing_space(files, fix=True)
# NOTE(review): the signature is annotated -> None, so the "and returns" in
# the docstring below looks stale — confirm and update.
210 def print_format_check(errors: Dict[Path, str],
211 show_fix_commands: bool) -> None:
212 """Prints and returns the result of a check_*_format function."""
214 # Don't print anything in the all-good case.
217 # Show the format fixing diff suggested by the tooling (with colors).
218 _LOG.warning('Found %d files with formatting errors. Format changes:',
220 for diff in errors.values():
223 # Show a copy-and-pastable command to fix the issues.
224 if show_fix_commands:
226 def path_relative_to_cwd(path):
# Two candidate results: the resolved path relative to the resolved current
# directory, or the resolved absolute path. NOTE(review): the branching
# between them is not visible here; presumably the second is the fallback
# when relative_to() fails — confirm.
228 return Path(path).resolve().relative_to(Path.cwd().resolve())
230 return Path(path).resolve()
# One 'pw format --fix <path>' command per entry, joined with newlines below.
232 message = (f' pw format --fix {path_relative_to_cwd(path)}'
234 _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))
# Describes how to check and fix one kind of source file.
237 class CodeFormat(NamedTuple):
# Path suffixes that select files for this format (matched with endswith).
239 extensions: Collection[str]
# Patterns used as the default 'exclude' filter in presubmit_check.
240 exclude: Collection[str]
# Takes the files to check and returns {path: diff} for files with errors.
241 check: Callable[[Iterable], Dict[Path, str]]
# Takes the files to fix and returns nothing.
242 fix: Callable[[Iterable], None]
# One CodeFormat per supported file type. Where the full call is visible, the
# positional arguments are: language name, path suffixes, exclude patterns,
# check function, fix function.
245 C_FORMAT: CodeFormat = CodeFormat(
247 frozenset(['.h', '.hh', '.hpp', '.c', '.cc', '.cpp', '.inc', '.inl']),
# Paths ending in .pb.h / .pb.c are excluded from C/C++ formatting.
248 (r'\.pb\.h$', r'\.pb\.c$'), clang_format_check, clang_format_fix)
# Protocol buffer and Java files reuse the clang-format check and fix.
250 PROTO_FORMAT: CodeFormat = CodeFormat('Protocol buffer', ('.proto', ), (),
251 clang_format_check, clang_format_fix)
253 JAVA_FORMAT: CodeFormat = CodeFormat('Java', ('.java', ), (),
254 clang_format_check, clang_format_fix)
256 JAVASCRIPT_FORMAT: CodeFormat = CodeFormat('JavaScript', ('.js', ), (),
260 GO_FORMAT: CodeFormat = CodeFormat('Go', ('.go', ), (), check_go_format,
263 PYTHON_FORMAT: CodeFormat = CodeFormat('Python', ('.py', ), (),
264 check_py_format, fix_py_format)
266 GN_FORMAT: CodeFormat = CodeFormat('GN', ('.gn', '.gni'), (), check_gn_format,
269 # TODO(pwbug/191): Add real code formatting support for Bazel and CMake
# Until then, these formats are only checked for trailing whitespace.
270 BAZEL_FORMAT: CodeFormat = CodeFormat('Bazel', ('BUILD', ), (),
271 check_trailing_space, fix_trailing_space)
273 CMAKE_FORMAT: CodeFormat = CodeFormat('CMake', ('CMakeLists.txt', '.cmake'),
274 (), check_trailing_space,
277 RST_FORMAT: CodeFormat = CodeFormat('reStructuredText', ('.rst', ), (),
278 check_trailing_space, fix_trailing_space)
280 MARKDOWN_FORMAT: CodeFormat = CodeFormat('Markdown', ('.md', ), (),
281 check_trailing_space,
# The tuple of formats iterated by presubmit_checks and CodeFormatter.
284 CODE_FORMATS: Tuple[CodeFormat, ...] = (
299 def presubmit_check(code_format: CodeFormat, **filter_paths_args) -> Callable:
300 """Creates a presubmit check function from a CodeFormat object."""
# Caller-supplied 'endswith' / 'exclude' filters take precedence; the
# CodeFormat's own extensions and exclude patterns are only defaults.
301 filter_paths_args.setdefault('endswith', code_format.extensions)
302 filter_paths_args.setdefault('exclude', code_format.exclude)
304 @pw_presubmit.filter_paths(**filter_paths_args)
305 def check_code_format(ctx: pw_presubmit.PresubmitContext):
306 errors = code_format.check(ctx.paths)
309 # When running as part of presubmit, show the fix command help.
310 show_fix_commands=True,
313 raise pw_presubmit.PresubmitFailure
# Name the check after the language: lowercased, '+' replaced by 'p' and
# spaces by '_', with a '_format' suffix.
315 language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
316 check_code_format.__name__ = f'{language}_format'
318 return check_code_format
321 def presubmit_checks(**filter_paths_args) -> Tuple[Callable, ...]:
322 """Returns a tuple with all supported code format presubmit checks."""
# The same filter_paths_args are forwarded to presubmit_check for every entry
# in CODE_FORMATS.
324 presubmit_check(fmt, **filter_paths_args) for fmt in CODE_FORMATS)
328 """Checks or fixes the formatting of a set of files."""
def __init__(self, files: Iterable[Path]):
    """Groups the given files under every code format whose suffix matches.

    Args:
      files: Paths to check or fix; kept as a list in self.paths.
    """
    self.paths = list(files)
    self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)

    for path in self.paths:
        posix_path = path.as_posix()
        for code_format in CODE_FORMATS:
            # str.endswith accepts a tuple of candidate suffixes, so a single
            # call covers all of the format's extensions.
            if posix_path.endswith(tuple(code_format.extensions)):
                self._formats[code_format].append(path)
def check(self) -> Dict[Path, str]:
    """Runs each format's check and returns the combined {path: diff}.

    The result is an OrderedDict built from the sorted (path, diff) items.
    """
    found: Dict[Path, str] = {}

    for code_format in self._formats:
        matching = self._formats[code_format]
        _LOG.debug('Checking %s', ', '.join(map(str, matching)))
        found.update(code_format.check(matching))

    return collections.OrderedDict(sorted(found.items()))
def fix(self) -> None:
    """Applies each format's fix function to its files and logs a summary."""
    for code_format, files in self._formats.items():
        code_format.fix(files)
        description = plural(files, code_format.language + ' file')
        _LOG.info('Formatted %s', description)
# Each file is resolved and expressed relative to the resolved `base`.
# Path.relative_to raises ValueError for a path that is not under `base`.
# NOTE(review): what happens to these relative paths is not visible here;
# presumably they are passed to file_summary — confirm.
357 def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
360 Path(f).resolve().relative_to(base.resolve()) for f in files)
365 def format_paths_in_repo(paths: Collection[Union[Path, str]],
366 exclude: Collection[Pattern[str]], fix: bool,
368 """Checks or fixes formatting for files in a Git repo."""
# Paths that name existing files are taken directly (resolved to absolute
# paths); any other entry in `paths` is left out of this initial list.
369 files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
# repo stays None when not running inside a Git repository.
370 repo = git_repo.root() if git_repo.is_repo() else None
372 # If this is a Git repo, list the original paths with git ls-files or diff.
376 git_repo.describe_files(repo, Path.cwd(), base, paths, exclude))
378 # Add files from Git and remove duplicates.
380 set(exclude_paths(exclude, git_repo.list_files(base, paths)))
384 'A base commit may only be provided if running from a Git repo')
# The collected files are handed to format_files, whose int result is returned.
387 return format_files(files, fix, repo=repo)
390 def format_files(paths: Collection[Union[Path, str]],
392 repo: Optional[Path] = None) -> int:
393 """Checks or fixes formatting for the specified files."""
394 formatter = CodeFormatter(Path(p) for p in paths)
396 _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))
# The file summary is printed to stderr, relative to the repo root when one
# is given and to the current directory otherwise.
398 for line in _file_summary(paths, repo if repo else Path.cwd()):
399 print(line, file=sys.stderr)
401 errors = formatter.check()
# Fix commands are only suggested when this run is not applying fixes itself.
402 print_format_check(errors, show_fix_commands=(not fix))
407 # TODO: This should perhaps check that the fixes were successful.
408 _LOG.info('Formatting fixes applied successfully')
411 _LOG.error('Formatting errors found')
414 _LOG.info('Congratulations! No formatting changes needed')
418 def arguments(git_paths: bool) -> argparse.ArgumentParser:
419 """Creates an argument parser for format_files or format_paths_in_repo."""
# The module docstring doubles as the command-line description.
421 parser = argparse.ArgumentParser(description=__doc__)
424 cli.add_path_arguments(parser)
# Converts an argument to a Path, rejecting anything that is not an existing
# file with argparse.ArgumentTypeError.
427 def existing_path(arg: str) -> Path:
429 if not path.is_file():
430 raise argparse.ArgumentTypeError(
431 f'{arg} is not a path to a file')
435 parser.add_argument('paths',
439 help='File paths to check')
441 parser.add_argument('--fix',
443 help='Apply formatting fixes in place.')
448 """Check and fix formatting for source files."""
449 return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args()))
# Script entry point: configure logging with pw_cli's log.install when pw_cli
# is available, and with logging.basicConfig otherwise.
452 if __name__ == '__main__':
454 # If pw_cli is available, use it to initialize logs.
455 from pw_cli import log
457 log.install(logging.INFO)
459 # If pw_cli isn't available, display log messages like a simple print.
460 logging.basicConfig(format='%(message)s', level=logging.INFO)