1 # Copyright 2020 The Pigweed Authors
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
7 # https://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
14 """Utilities for running unit tests over Pigweed RPC."""
17 from dataclasses import dataclass
19 from typing import Iterable
22 from pw_rpc.callback_client import OptionalTimeout, UseDefault
23 from pw_unit_test_proto import unit_test_pb2
# Module logger; LoggingEventHandler writes all of its test output here.
_LOG = logging.getLogger(__name__)
28 @dataclass(frozen=True)
34 def __str__(self) -> str:
35 return f'{self.suite_name}.{self.test_name}'
37 def __repr__(self) -> str:
38 return f'TestCase({str(self)})'
@dataclass(frozen=True)
class TestExpectation:
    """The outcome of a single expect/assert statement inside a test case."""
    # Expression text; LoggingEventHandler reports it as "Expected".
    expression: str
    # Evaluated form of the expression; LoggingEventHandler reports it as
    # "Actual".
    evaluated_expression: str
    # Line of the statement within the test case's file_name.
    line_number: int
    # Whether the expectation held.
    success: bool

    def __str__(self) -> str:
        return self.expression

    def __repr__(self) -> str:
        # str.format() on the instance goes through __str__ above.
        return 'TestExpectation({})'.format(self)
class EventHandler(abc.ABC):
    """Receives the events that run_tests() reports during a unit test run.

    Every method is abstract, so a concrete handler must implement all of
    them.
    """

    @abc.abstractmethod
    def run_all_tests_start(self):
        """Called before all tests are run."""

    @abc.abstractmethod
    def run_all_tests_end(self, passed_tests: int, failed_tests: int):
        """Called after the test run is complete.

        Args:
            passed_tests: Number of passing tests reported by the device.
            failed_tests: Number of failing tests reported by the device.
        """

    @abc.abstractmethod
    def test_case_start(self, test_case: TestCase):
        """Called when a new test case is started.

        Args:
            test_case: The test case that is about to run.
        """

    @abc.abstractmethod
    def test_case_end(self, test_case: TestCase, result: int):
        """Called when a test case completes with its overall result.

        Args:
            test_case: The test case that just finished.
            result: Overall outcome; compared against
                unit_test_pb2.TestCaseResult values such as SUCCESS.
        """

    @abc.abstractmethod
    def test_case_disabled(self, test_case: TestCase):
        """Called when a disabled test case is encountered.

        Args:
            test_case: The disabled test case.
        """

    @abc.abstractmethod
    def test_case_expect(self, test_case: TestCase,
                         expectation: TestExpectation):
        """Called after each expect/assert statement within a test case.

        Args:
            test_case: The test case containing the statement.
            expectation: The evaluated expectation and whether it held.
        """
class LoggingEventHandler(EventHandler):
    """Event handler that logs test events using Google Test format.

    Status tags are padded to a fixed width of 12 columns (matching
    "[==========]"), as Google Test does, so test names line up in the log.
    """
    def run_all_tests_start(self):
        _LOG.info('[==========] Running all tests.')

    def run_all_tests_end(self, passed_tests: int, failed_tests: int):
        _LOG.info('[==========] Done running all tests.')
        _LOG.info('[  PASSED  ] %d test(s).', passed_tests)
        # Like Google Test, the FAILED summary only appears when needed.
        if failed_tests:
            _LOG.info('[  FAILED  ] %d test(s).', failed_tests)

    def test_case_start(self, test_case: TestCase):
        _LOG.info('[ RUN      ] %s', test_case)

    def test_case_end(self, test_case: TestCase, result: int):
        if result == unit_test_pb2.TestCaseResult.SUCCESS:
            _LOG.info('[       OK ] %s', test_case)
        else:
            _LOG.info('[  FAILED  ] %s', test_case)

    def test_case_disabled(self, test_case: TestCase):
        _LOG.info('Skipping disabled test %s', test_case)

    def test_case_expect(self, test_case: TestCase,
                         expectation: TestExpectation):
        """Logs an expectation; failures are logged at ERROR level."""
        result = 'Success' if expectation.success else 'Failure'
        log = _LOG.info if expectation.success else _LOG.error
        log('%s:%d: %s', test_case.file_name, expectation.line_number, result)
        # Padded so the "Expected:" and "Actual:" colons line up.
        log('      Expected: %s', expectation.expression)
        log('        Actual: %s', expectation.evaluated_expression)
def run_tests(rpcs: pw_rpc.client.Services,
              report_passed_expectations: bool = False,
              event_handlers: Iterable[EventHandler] = (
                  LoggingEventHandler(), ),
              timeout_s: OptionalTimeout = UseDefault.VALUE) -> bool:
    """Runs unit tests on a device over Pigweed RPC.

    Calls each of the provided event handlers as test events occur, and returns
    True if all tests pass.

    Args:
        rpcs: Client services; must expose the pw.unit_test.UnitTest service.
        report_passed_expectations: Forwarded to the Run RPC unchanged.
        event_handlers: Handlers notified, in order, of every test event. Any
            iterable is accepted; it is consumed exactly once.
        timeout_s: Timeout passed to the Run RPC as pw_rpc_timeout_s.

    Returns:
        True if the device sent a test_run_end response reporting zero
        failures; False otherwise, including when no test_run_end arrives.

    Raises:
        ValueError: If the response stream is empty or does not begin with a
            test_run_start message.
    """
    unit_test_service = rpcs.pw.unit_test.UnitTest  # type: ignore[attr-defined]

    # Materialize the handlers once. The parameter is typed as Iterable, and a
    # one-shot iterator (e.g. a generator) would otherwise be exhausted by the
    # first loop over it, silently dropping every later event.
    handlers = tuple(event_handlers)

    test_responses = iter(
        unit_test_service.Run(
            report_passed_expectations=report_passed_expectations,
            pw_rpc_timeout_s=timeout_s))

    # Read the first response, which must be a test_run_start message. The
    # default lets an empty stream surface as a ValueError instead of leaking
    # a bare StopIteration to the caller.
    first_response = next(test_responses, None)
    if first_response is None:
        raise ValueError(
            'pw.unit_test.Run finished without sending any responses.')
    if not first_response.HasField('test_run_start'):
        raise ValueError(
            'Expected a "test_run_start" response from pw.unit_test.Run, '
            'but received a different message type. A response may have been '
            'dropped.')

    for event_handler in handlers:
        event_handler.run_all_tests_start()

    all_tests_passed = False

    for response in test_responses:
        # Per-response state is updated here, outside the handler loop, so the
        # outcome does not depend on how many handlers (if any) were given.
        if response.HasField('test_case_start'):
            raw_test_case = response.test_case_start
            current_test_case = TestCase(raw_test_case.suite_name,
                                         raw_test_case.test_name,
                                         raw_test_case.file_name)

        if (response.HasField('test_run_end')
                and response.test_run_end.failed == 0):
            all_tests_passed = True

        if response.HasField('test_case_expectation'):
            raw_expectation = response.test_case_expectation
            expectation = TestExpectation(
                raw_expectation.expression,
                raw_expectation.evaluated_expression,
                raw_expectation.line_number,
                raw_expectation.success,
            )

        # NOTE(review): test_case_end/disabled/expectation responses that
        # arrive before any test_case_start leave current_test_case unbound.
        for event_handler in handlers:
            if response.HasField('test_run_start'):
                event_handler.run_all_tests_start()
            elif response.HasField('test_run_end'):
                event_handler.run_all_tests_end(response.test_run_end.passed,
                                                response.test_run_end.failed)
            elif response.HasField('test_case_start'):
                event_handler.test_case_start(current_test_case)
            elif response.HasField('test_case_end'):
                event_handler.test_case_end(current_test_case,
                                            response.test_case_end)
            elif response.HasField('test_case_disabled'):
                event_handler.test_case_disabled(current_test_case)
            elif response.HasField('test_case_expectation'):
                event_handler.test_case_expect(current_test_case, expectation)

    return all_tests_passed