1 # Copyright 2020 The Pigweed Authors
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
7 # https://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
14 """Tools for generating Pigweed tests that execute in C++ and Python."""
import argparse
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
import unittest
from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
                    Sequence, TextIO, TypeVar, Union)
26 // Copyright {datetime.now().year} The Pigweed Authors
28 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
29 // use this file except in compliance with the License. You may obtain a copy of
32 // https://www.apache.org/licenses/LICENSE-2.0
34 // Unless required by applicable law or agreed to in writing, software
35 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
36 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
37 // License for the specific language governing permissions and limitations under
40 // AUTOGENERATED - DO NOT EDIT
42 // Generated at {datetime.now().isoformat()}
class Error(Exception):
    """Raised when tests cannot be generated, e.g. from invalid test cases."""
T = TypeVar('T')


@dataclass
class Context(Generic[T]):
    """Describes one test case and where it sits within its named group.

    An instance is handed to the test generator functions for every test case.
    """
    group: str  # Name of the group the test case belongs to.
    count: int  # Position within the group (TestGenerator numbers from 1).
    total: int  # Number of test cases in the group.
    test_case: T  # The test case itself.

    def cc_name(self) -> str:
        """Returns a C++ test name derived from the group name.

        The group name is split on spaces and dashes, and each word is
        capitalized (first letter upper case, the rest lower case) and joined.
        Any remaining non-alphanumeric characters become underscores.
        """
        words = self.group.replace('-', ' ').split(' ')
        camel_case = ''.join(map(str.capitalize, words))
        identifier = ''.join(ch if ch.isalnum() else '_' for ch in camel_case)
        return self._numbered(identifier)

    def py_name(self) -> str:
        """Returns a Python test method name derived from the group name.

        The name is 'test_' plus the lower-cased group name, with every
        non-alphanumeric character replaced by an underscore.
        """
        slug = ''.join(ch if ch.isalnum() else '_'
                       for ch in self.group.lower())
        return self._numbered('test_' + slug)

    def _numbered(self, base: str) -> str:
        """Appends _<count> to base when the group has more than one case."""
        return base if self.total <= 1 else f'{base}_{self.count}'
# Test cases are specified as a sequence of strings and test case instances.
# Each string names a group, and every test case belongs to the group named by
# the closest string before it, so the first item must be a group name. For
# example:
#
#   STR_SPLIT_TEST_CASES = (
#       'Empty input',
#       MyTestCase('', '', []),
#       MyTestCase('', 'foo', []),
#       'Split on single character',
#       MyTestCase('abcde', 'c', ['ab', 'de']),
#   )
# One entry of a test case sequence: either a group name or a test case.
GroupOrTest = Union[str, T]

# A Python test is a function usable as a unittest.TestCase method: it takes
# the TestCase instance and returns nothing. A Python test generator builds one
# such function from the Context of a single test case.
PyTest = Callable[[unittest.TestCase], None]
PyTestGenerator = Callable[[Context[T]], PyTest]

# A C++ test generator returns or yields the lines of C++ source code for the
# test case described by the given Context.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]
class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases.

    The test cases are a sequence mixing group-name strings and test case
    instances; each test case belongs to the group named by the closest string
    before it.
    """
    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Sorts test_cases into named groups.

        Raises:
          Error: if no test case is provided, if the first item is not a group
              name string, or if a group name is empty.
        """
        self._cases: Dict[str, List[T]] = defaultdict(list)

        # A group name plus one test case is the smallest valid input; this
        # also guards the test_cases[0] access below.
        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string')

        group = ''
        for case in test_cases:
            if isinstance(case, str):
                group = case
            else:
                self._cases[group].append(case)

        # The length check above still accepts input made only of group names
        # (e.g. ['a', 'b']), which contains no test cases at all.
        if not self._cases:
            raise Error('At least one test case must be provided')

        if '' in self._cases:
            raise Error('Empty test group names are not permitted')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context for every test case, group by group.

        Test cases are numbered from 1 within their group.
        """
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(self, define_py_test: PyTestGenerator[T]):
        """Returns a dict mapping each Python test method name to its function.

        Raises:
          Error: if two test cases produce the same Python test name.
        """
        tests: Dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            name = ctx.py_name()

            if name in tests:
                raise Error(f'Multiple Python tests are named {name}!')

            test = define_py_test(ctx)
            test.__name__ = name
            tests[name] = test

        return tests

    def python_tests(self, name: str,
                     define_py_test: PyTestGenerator[T]) -> type:
        """Returns a Python unittest.TestCase class with tests for each case.

        Args:
          name: name of the generated class.
          define_py_test: called with each test case's Context; returns the
              test method for that case.
        """
        return type(name, (unittest.TestCase, ),
                    self._generate_python_tests(define_py_test))

    def _generate_cc_tests(self, define_cpp_test: CcTestGenerator[T],
                           header: str, footer: str) -> Iterator[str]:
        """Yields the lines of the C++ test file.

        The output is header, then each test case's lines followed by an empty
        separator line, then footer.

        NOTE(review): the C++ license / "AUTOGENERATED - DO NOT EDIT" banner
        text near the top of the module is not emitted here; confirm whether it
        should be yielded before header.
        """
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''  # Blank line between consecutive tests.

        yield footer

    def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator[T],
                 header: str, footer: str):
        """Writes C++ unit tests for each test case to the given file.

        Args:
          output: text stream the generated C++ source is written to.
          define_cpp_test: called with each test case's Context; returns or
              yields the lines of C++ code for that case.
          header: text written before the tests.
          footer: text written after the tests.
        """
        for line in self._generate_cc_tests(define_cpp_test, header, footer):
            output.write(line)
            output.write('\n')
165 def _to_chars(data: bytes) -> Iterator[str]:
166 for i, byte in enumerate(data):
168 char = data[i:i + 1].decode()
169 yield char if char.isprintable() else fr'\x{byte:02x}'
170 except UnicodeDecodeError:
171 yield fr'\x{byte:02x}'
def cc_string(data: Union[str, bytes]) -> str:
    """Formats data as a double-quoted C++ string literal.

    A str argument is first encoded as UTF-8; bytes are used unchanged.
    """
    raw = data.encode() if isinstance(data, str) else data
    body = ''.join(_to_chars(raw))
    return f'"{body}"'
def parse_test_generation_args() -> argparse.Namespace:
    """Parses the command-line options used for test generation.

    Recognizes --generate-cc-test PATH, which opens PATH for writing as the
    C++ test file; the attribute is None when the option is absent.
    Unrecognized arguments are ignored instead of causing an error.
    """
    parser = argparse.ArgumentParser(description='Generate unit test files')
    parser.add_argument(
        '--generate-cc-test',
        type=argparse.FileType('w'),
        help='Generate the C++ test file',
    )
    known, _unknown = parser.parse_known_args()
    return known