Fix for x86_64 build fail
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_unit_test / py / pw_unit_test / rpc.py
1 # Copyright 2020 The Pigweed Authors
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
5 # the License at
6 #
7 #     https://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
13 # the License.
14 """Utilities for running unit tests over Pigweed RPC."""
15
16 import abc
17 from dataclasses import dataclass
18 import logging
19 from typing import Iterable
20
21 import pw_rpc.client
22 from pw_rpc.callback_client import OptionalTimeout, UseDefault
23 from pw_unit_test_proto import unit_test_pb2
24
25 _LOG = logging.getLogger(__name__)
26
27
@dataclass(frozen=True)
class TestCase:
    """Immutable identity of a single test case: suite, test, and source file.

    Displays as ``Suite.Test``; the file name is carried along for reporting
    expectation locations.
    """
    suite_name: str
    test_name: str
    file_name: str

    def __str__(self) -> str:
        # Dotted "Suite.Test" form, matching how Google Test names cases.
        return '{}.{}'.format(self.suite_name, self.test_name)

    def __repr__(self) -> str:
        # Wraps the dotted name rather than listing every field.
        return 'TestCase({})'.format(self)
39
40
@dataclass(frozen=True)
class TestExpectation:
    """Immutable record of one expect/assert statement evaluated in a test.

    Holds the source expression, its evaluated form, the line it appears on,
    and whether it held.
    """
    expression: str
    evaluated_expression: str
    line_number: int
    success: bool

    def __str__(self) -> str:
        # The textual form is just the original source expression.
        return self.expression

    def __repr__(self) -> str:
        return 'TestExpectation({})'.format(self)
53
54
class EventHandler(abc.ABC):
    """Abstract receiver of unit test events; subclasses must override all."""
    @abc.abstractmethod
    def run_all_tests_start(self):
        """Invoked when a test run begins."""

    @abc.abstractmethod
    def run_all_tests_end(self, passed_tests: int, failed_tests: int):
        """Invoked when a test run finishes, with pass/fail counts."""

    @abc.abstractmethod
    def test_case_start(self, test_case: TestCase):
        """Invoked when execution of test_case begins."""

    @abc.abstractmethod
    def test_case_end(self, test_case: TestCase, result: int):
        """Invoked when test_case finishes; result is its overall outcome code.

        The result is compared against unit_test_pb2.TestCaseResult values by
        the logging implementation.
        """

    @abc.abstractmethod
    def test_case_disabled(self, test_case: TestCase):
        """Invoked for a test case that is disabled and was not run."""

    @abc.abstractmethod
    def test_case_expect(self, test_case: TestCase,
                         expectation: TestExpectation):
        """Invoked for every expect/assert evaluated inside test_case."""
80
81
class LoggingEventHandler(EventHandler):
    """Event handler that logs test events using Google Test format."""
    def run_all_tests_start(self):
        _LOG.info('[==========] Running all tests.')

    def run_all_tests_end(self, passed_tests: int, failed_tests: int):
        _LOG.info('[==========] Done running all tests.')
        _LOG.info('[  PASSED  ] %d test(s).', passed_tests)
        # The failure summary line is only emitted when something failed.
        if not failed_tests:
            return
        _LOG.info('[  FAILED  ] %d test(s).', failed_tests)

    def test_case_start(self, test_case: TestCase):
        _LOG.info('[ RUN      ] %s', test_case)

    def test_case_end(self, test_case: TestCase, result: int):
        # Anything other than SUCCESS is reported with the FAILED banner.
        passed = result == unit_test_pb2.TestCaseResult.SUCCESS
        banner = '[       OK ] %s' if passed else '[  FAILED  ] %s'
        _LOG.info(banner, test_case)

    def test_case_disabled(self, test_case: TestCase):
        _LOG.info('Skipping disabled test %s', test_case)

    def test_case_expect(self, test_case: TestCase,
                         expectation: TestExpectation):
        # Passing expectations go to INFO, failing ones to ERROR.
        if expectation.success:
            emit, outcome = _LOG.info, 'Success'
        else:
            emit, outcome = _LOG.error, 'Failure'
        emit('%s:%d: %s', test_case.file_name, expectation.line_number, outcome)
        emit('      Expected: %s', expectation.expression)
        emit('        Actual: %s', expectation.evaluated_expression)
112
113
def _test_case_from_descriptor(descriptor) -> TestCase:
    """Builds a TestCase from a message with suite/test/file name fields."""
    return TestCase(descriptor.suite_name, descriptor.test_name,
                    descriptor.file_name)


def _require_test_case(test_case, event_name: str) -> TestCase:
    """Returns test_case, raising ValueError if no test case has started."""
    if test_case is None:
        raise ValueError(
            f'Received a "{event_name}" response from pw.unit_test.Run '
            'before any "test_case_start" response. A response may have been '
            'dropped.')
    return test_case


def run_tests(rpcs: pw_rpc.client.Services,
              report_passed_expectations: bool = False,
              event_handlers: Iterable[EventHandler] = (
                  LoggingEventHandler(), ),
              timeout_s: OptionalTimeout = UseDefault.VALUE) -> bool:
    """Runs unit tests on a device over Pigweed RPC.

    Calls each of the provided event handlers as test events occur, and returns
    True if all tests pass.

    Args:
        rpcs: RPC services exposing pw.unit_test.UnitTest.
        report_passed_expectations: Forwarded as the Run request's
            report_passed_expectations field.
        event_handlers: Handlers notified of every test event, in order. Any
            iterable is accepted; it is consumed once up front.
        timeout_s: Passed to the Run call as pw_rpc_timeout_s.

    Returns:
        True if a "test_run_end" response reported zero failed tests.

    Raises:
        ValueError: The response stream was empty, did not begin with a
            "test_run_start" response, or reported a test case end or
            expectation before any "test_case_start" response.
    """
    # Materialize the handlers once: they are iterated for every response, and
    # a one-shot iterable (e.g. a generator) would be exhausted after one use.
    handlers = tuple(event_handlers)

    unit_test_service = rpcs.pw.unit_test.UnitTest  # type: ignore[attr-defined]

    test_responses = iter(
        unit_test_service.Run(
            report_passed_expectations=report_passed_expectations,
            pw_rpc_timeout_s=timeout_s))

    # Read the first response, which must be a test_run_start message. The
    # default keeps an empty stream from leaking StopIteration to the caller.
    first_response = next(test_responses, None)
    if first_response is None:
        raise ValueError(
            'Expected a "test_run_start" response from pw.unit_test.Run, '
            'but the response stream was empty.')
    if not first_response.HasField('test_run_start'):
        raise ValueError(
            'Expected a "test_run_start" response from pw.unit_test.Run, '
            'but received a different message type. A response may have been '
            'dropped.')

    for handler in handlers:
        handler.run_all_tests_start()

    all_tests_passed = False
    # Test case most recently announced by "test_case_start"; end and
    # expectation events are attributed to it.
    current_test_case = None

    for response in test_responses:
        if response.HasField('test_run_start'):
            for handler in handlers:
                handler.run_all_tests_start()
        elif response.HasField('test_run_end'):
            summary = response.test_run_end
            # Decided outside the handler loop so the result does not depend
            # on whether any handlers were supplied.
            if summary.failed == 0:
                all_tests_passed = True
            for handler in handlers:
                handler.run_all_tests_end(summary.passed, summary.failed)
        elif response.HasField('test_case_start'):
            current_test_case = _test_case_from_descriptor(
                response.test_case_start)
            for handler in handlers:
                handler.test_case_start(current_test_case)
        elif response.HasField('test_case_end'):
            test_case = _require_test_case(current_test_case, 'test_case_end')
            for handler in handlers:
                handler.test_case_end(test_case, response.test_case_end)
        elif response.HasField('test_case_disabled'):
            # Report the disabled test from the event's own payload rather
            # than whichever test case started last.
            # NOTE(review): assumes test_case_disabled carries a descriptor
            # with suite_name/test_name/file_name, as in unit_test.proto.
            disabled_test_case = _test_case_from_descriptor(
                response.test_case_disabled)
            for handler in handlers:
                handler.test_case_disabled(disabled_test_case)
        elif response.HasField('test_case_expectation'):
            test_case = _require_test_case(current_test_case,
                                           'test_case_expectation')
            raw_expectation = response.test_case_expectation
            expectation = TestExpectation(
                raw_expectation.expression,
                raw_expectation.evaluated_expression,
                raw_expectation.line_number,
                raw_expectation.success,
            )
            for handler in handlers:
                handler.test_case_expect(test_case, expectation)

    return all_tests_passed