Fix for x86_64 build fail
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_build / py / pw_build / generated_tests.py
1 # Copyright 2020 The Pigweed Authors
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
5 # the License at
6 #
7 #     https://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
13 # the License.
14 """Tools for generating Pigweed tests that execute in C++ and Python."""
15
16 import argparse
17 from dataclasses import dataclass
18 from datetime import datetime
19 from collections import defaultdict
20 import unittest
21
22 from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
23                     Sequence, TextIO, TypeVar, Union)
24
25 _CPP_HEADER = f"""\
26 // Copyright {datetime.now().year} The Pigweed Authors
27 //
28 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
29 // use this file except in compliance with the License. You may obtain a copy of
30 // the License at
31 //
32 //     https://www.apache.org/licenses/LICENSE-2.0
33 //
34 // Unless required by applicable law or agreed to in writing, software
35 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
36 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
37 // License for the specific language governing permissions and limitations under
38 // the License.
39
40 // AUTOGENERATED - DO NOT EDIT
41 //
42 // Generated at {datetime.now().isoformat()}
43
44 // clang-format off
45 """
46
47
48 class Error(Exception):
49     """Something went wrong when generating tests."""
50
51
52 T = TypeVar('T')
53
54
55 @dataclass
56 class Context(Generic[T]):
57     """Info passed into test generator functions for each test case."""
58     group: str
59     count: int
60     total: int
61     test_case: T
62
63     def cc_name(self) -> str:
64         name = ''.join(w.capitalize()
65                        for w in self.group.replace('-', ' ').split(' '))
66         name = ''.join(c if c.isalnum() else '_' for c in name)
67         return f'{name}_{self.count}' if self.total > 1 else name
68
69     def py_name(self) -> str:
70         name = 'test_' + ''.join(c if c.isalnum() else '_'
71                                  for c in self.group.lower())
72         return f'{name}_{self.count}' if self.total > 1 else name
73
74
75 # Test cases are specified as a sequence of strings or test case instances. The
76 # strings are used to separate the tests into named groups. For example:
77 #
78 #   STR_SPLIT_TEST_CASES = (
79 #     'Empty input',
80 #     MyTestCase('', '', []),
81 #     MyTestCase('', 'foo', []),
82 #     'Split on single character',
83 #     MyTestCase('abcde', 'c', ['ab', 'de']),
84 #     ...
85 #   )
86 #
87 GroupOrTest = Union[str, T]
88
89 # Python tests are generated by a function that returns a function usable as a
90 # unittest.TestCase method.
91 PyTest = Callable[[unittest.TestCase], None]
92 PyTestGenerator = Callable[[Context[T]], PyTest]
93
94 # C++ tests are generated with a function that returns or yields lines of C++
95 # code for the given test case.
96 CcTestGenerator = Callable[[Context[T]], Iterable[str]]
97
98
99 class TestGenerator(Generic[T]):
100     """Generates tests for multiple languages from a series of test cases."""
101     def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
102         self._cases: Dict[str, List[T]] = defaultdict(list)
103         message = ''
104
105         if len(test_cases) < 2:
106             raise Error('At least one test case must be provided')
107
108         if not isinstance(test_cases[0], str):
109             raise Error(
110                 'The first item in the test cases must be a group name string')
111
112         for case in test_cases:
113             if isinstance(case, str):
114                 message = case
115             else:
116                 self._cases[message].append(case)
117
118         if '' in self._cases:
119             raise Error('Empty test group names are not permitted')
120
121     def _test_contexts(self) -> Iterator[Context[T]]:
122         for group, test_list in self._cases.items():
123             for i, test_case in enumerate(test_list, 1):
124                 yield Context(group, i, len(test_list), test_case)
125
126     def _generate_python_tests(self, define_py_test: PyTestGenerator):
127         tests: Dict[str, Callable[[Any], None]] = {}
128
129         for ctx in self._test_contexts():
130             test = define_py_test(ctx)
131             test.__name__ = ctx.py_name()
132
133             if test.__name__ in tests:
134                 raise Error(
135                     f'Multiple Python tests are named {test.__name__}!')
136
137             tests[test.__name__] = test
138
139         return tests
140
141     def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
142         """Returns a Python unittest.TestCase class with tests for each case."""
143         return type(name, (unittest.TestCase, ),
144                     self._generate_python_tests(define_py_test))
145
146     def _generate_cc_tests(self, define_cpp_test: CcTestGenerator, header: str,
147                            footer: str) -> Iterator[str]:
148         yield _CPP_HEADER
149         yield header
150
151         for ctx in self._test_contexts():
152             yield from define_cpp_test(ctx)
153             yield ''
154
155         yield footer
156
157     def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator,
158                  header: str, footer: str):
159         """Writes C++ unit tests for each test case to the given file."""
160         for line in self._generate_cc_tests(define_cpp_test, header, footer):
161             output.write(line)
162             output.write('\n')
163
164
165 def _to_chars(data: bytes) -> Iterator[str]:
166     for i, byte in enumerate(data):
167         try:
168             char = data[i:i + 1].decode()
169             yield char if char.isprintable() else fr'\x{byte:02x}'
170         except UnicodeDecodeError:
171             yield fr'\x{byte:02x}'
172
173
174 def cc_string(data: Union[str, bytes]) -> str:
175     """Returns a C++ string literal version of a byte string or UTF-8 string."""
176     if isinstance(data, str):
177         data = data.encode()
178
179     return '"' + ''.join(_to_chars(data)) + '"'
180
181
182 def parse_test_generation_args() -> argparse.Namespace:
183     parser = argparse.ArgumentParser(description='Generate unit test files')
184     parser.add_argument('--generate-cc-test',
185                         type=argparse.FileType('w'),
186                         help='Generate the C++ test file')
187     return parser.parse_known_args()[0]