2 # Copyright 2021 The Pigweed Authors
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
5 # use this file except in compliance with the License. You may obtain a copy of
8 # https://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 # License for the specific language governing permissions and limitations under
15 """Tests using the callback client for pw_rpc."""
18 from unittest import mock
19 from typing import List, Tuple
21 from pw_protobuf_compiler import python_protos
22 from pw_rpc_protos.internal import packet_pb2
23 from pw_status import Status
25 from pw_rpc import callback_client, client, packets
33 uint32 magic_number = 1;
36 message AnotherMessage {
40 I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
47 service PublicService {
48 rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
49 rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
50 rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
51 rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
# Builds the client.PendingRpc that identifies a call made through
# `method_stub`, starting from the stub's channel and its method's service.
# NOTE(review): the continuation line with the remaining argument(s) is not
# visible in this excerpt.
56 def _rpc(method_stub):
57 return client.PendingRpc(method_stub.channel, method_stub.method.service,
61 class CallbackClientImplTest(unittest.TestCase):
62 """Tests the callback_client as used within a pw_rpc Client."""
# NOTE(review): the header of the method that owns the assignments below is
# not visible in this excerpt (presumably setUp — confirm).
# Build a proto library from TEST_PROTO_1 and keep a shortcut to the
# SomeMessage type used as the request in the tests.
64 self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
65 self._request = self._protos.packages.pw.test1.SomeMessage
# Client with one channel (id 1) whose output callable is
# self._handle_request, so this test case receives every packet sent on it.
67 self._client = client.Client.from_modules(
68 callback_client.Impl(), [client.Channel(1, self._handle_request)],
69 self._protos.modules())
# RPC stubs for pw.test1.PublicService on channel 1.
70 self._service = self._client.channel(1).rpcs.pw.test1.PublicService
# Most recently decoded outgoing packet; assigned by _handle_request.
72 self._last_request: packet_pb2.RpcPacket = None
# Queue of (serialized packet, status) pairs; _handle_request passes each
# packet to the client and asserts process_packet returns the paired status.
73 self._next_packets: List[Tuple[bytes, Status]] = []
# Guard flag read and toggled by _handle_request to avoid recursion.
74 self._send_responses_on_request = True
# Queues a serialized RESPONSE packet in self._next_packets.
# The target RPC comes either from `method` (its service id and method id)
# or from an explicit (service_id, method_id) pair in `ids`; on the `ids`
# path the assert requires `ids` to be set and `method` to be None.
# A bytes `response` is handled separately from a message object, which is
# serialized with SerializeToString().
# NOTE(review): several lines of this helper are not visible in this
# excerpt, including how `status` and `process_status` are used; presumably
# process_status is the expected process_packet() result stored alongside
# the packet — confirm.
76 def _enqueue_response(self,
79 status: Status = Status.OK,
82 ids: Tuple[int, int] = None,
83 process_status=Status.OK):
86 service_id, method_id = method.service.id, method.id
88 assert ids is not None and method is None
89 service_id, method_id = ids
91 if isinstance(response, bytes):
94 payload = response.SerializeToString()
96 self._next_packets.append(
97 (packet_pb2.RpcPacket(type=packet_pb2.PacketType.RESPONSE,
98 channel_id=channel_id,
99 service_id=service_id,
102 payload=payload).SerializeToString(),
# Queues a serialized SERVER_STREAM_END packet in self._next_packets for
# `method`'s service on `channel_id`, carrying `status.value` as its status.
# NOTE(review): some parameter and packet-field lines are not visible in
# this excerpt; presumably process_status is stored with the packet as the
# expected process_packet() result — confirm.
105 def _enqueue_stream_end(self,
108 status: Status = Status.OK,
109 process_status=Status.OK):
110 self._next_packets.append(
111 (packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_STREAM_END,
112 channel_id=channel_id,
113 service_id=method.service.id,
115 status=status.value).SerializeToString(),
# Queues a serialized SERVER_ERROR packet in self._next_packets for
# `method`'s service on `channel_id`, carrying `status.value` as its status.
# NOTE(review): most of the parameter list and part of the packet
# construction are not visible in this excerpt.
118 def _enqueue_error(self,
122 process_status=Status.OK):
123 self._next_packets.append(
124 (packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
125 channel_id=channel_id,
126 service_id=method.service.id,
128 status=status.value).SerializeToString(),
# Output callable installed on the test channel: receives each packet the
# client sends as raw bytes.
# It records the decoded packet in self._last_request, then hands every
# queued packet to the client, asserting that process_packet returns the
# status queued with it, and finally empties the queue.
# NOTE(review): the statement guarded by the `if` below is not visible in
# this excerpt (presumably an early return — confirm).
131 def _handle_request(self, data: bytes):
132 # Disable this method to prevent infinite recursion if processing the
133 # packet happens to send another packet.
134 if not self._send_responses_on_request:
# Mark the handler as busy while queued packets are being processed.
137 self._send_responses_on_request = False
139 self._last_request = packets.decode(data)
141 for packet, status in self._next_packets:
142 self.assertIs(status, self._client.process_packet(packet))
# Queue is consumed; re-enable the handler for the next outgoing packet.
144 self._next_packets.clear()
145 self._send_responses_on_request = True
# Asserts that a request packet has been captured, then parses its payload
# into a fresh `message_type` instance.
# NOTE(review): the final line of this helper is not visible in this
# excerpt; callers read fields from its result, so it presumably returns
# `message` — confirm.
147 def _sent_payload(self, message_type):
148 self.assertIsNotNone(self._last_request)
149 message = message_type()
150 message.ParseFromString(self._last_request.payload)
# Calls the unary stub directly: the queued response (status ABORTED,
# payload '0_o') must come back as the (status, response) pair, and the
# magic_number of the captured request payload is checked.
# NOTE(review): some lines, including the head of the payload assertion,
# are not visible in this excerpt.
153 def test_invoke_unary_rpc(self):
154 method = self._service.SomeUnary.method
157 self._enqueue_response(1, method, Status.ABORTED,
158 method.response_type(payload='0_o'))
160 status, response = self._service.SomeUnary(
161 method.request_type(magic_number=6))
165 self._sent_payload(method.request_type).magic_number)
167 self.assertIs(Status.ABORTED, status)
168 self.assertEqual('0_o', response.payload)
# Invokes the unary RPC through invoke() with a mock callback and expects
# the mock to have been called with (rpc, response) followed by
# (rpc, Status.ABORTED); the sent request payload is also inspected.
# NOTE(review): the remaining invoke() arguments and the head of the
# payload assertion are not visible in this excerpt.
170 def test_invoke_unary_rpc_with_callback(self):
171 method = self._service.SomeUnary.method
174 self._enqueue_response(1, method, Status.ABORTED,
175 method.response_type(payload='0_o'))
177 callback = mock.Mock()
178 self._service.SomeUnary.invoke(self._request(magic_number=5),
181 callback.assert_has_calls([
182 mock.call(_rpc(self._service.SomeUnary),
183 method.response_type(payload='0_o')),
184 mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
189 self._sent_payload(method.request_type).magic_number)
# A queued SERVER_ERROR packet with NOT_FOUND must make the blocking unary
# call raise callback_client.RpcError whose status is NOT_FOUND.
# NOTE(review): some lines of this test are not visible in this excerpt.
191 def test_unary_rpc_server_error(self):
192 method = self._service.SomeUnary.method
195 self._enqueue_error(1, method, Status.NOT_FOUND)
197 with self.assertRaises(callback_client.RpcError) as context:
198 self._service.SomeUnary(method.request_type(magic_number=6))
200 self.assertIs(context.exception.status, Status.NOT_FOUND)
# A callback that raises must not propagate out of invoke(): the test
# expects an ERROR-level record on the callback_client logger containing
# the exception message, and then checks the RPC can still be invoked.
# NOTE(review): the line that re-invokes the RPC and binds `status` is not
# visible in this excerpt.
202 def test_invoke_unary_rpc_callback_exceptions_suppressed(self):
203 stub = self._service.SomeUnary
205 self._enqueue_response(1, stub.method)
206 exception_msg = 'YOU BROKE IT O-]-<'
208 with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
209 stub.invoke(self._request(),
210 mock.Mock(side_effect=Exception(exception_msg)))
212 self.assertIn(exception_msg, ''.join(logs.output))
214 # Make sure we can still invoke the RPC.
215 self._enqueue_response(1, stub.method, Status.UNKNOWN)
217 self.assertIs(status, Status.UNKNOWN)
# Starts a unary call with a callback and no queued response, then checks:
#   * a second invoke() with override_pending=False raises client.Error;
#   * cancel() returns True the first time and False the second time;
#   * no packet is sent as a result of cancelling (self._last_request stays
#     None after being reset);
#   * the callback is never called.
# NOTE(review): several lines of this test are not visible in this excerpt.
219 def test_invoke_unary_rpc_with_callback_cancel(self):
220 callback = mock.Mock()
223 call = self._service.SomeUnary.invoke(
224 self._request(magic_number=55), callback)
226 self.assertIsNotNone(self._last_request)
227 self._last_request = None
229 # Try to invoke the RPC again before cancelling, without overriding
231 with self.assertRaises(client.Error):
232 self._service.SomeUnary.invoke(self._request(magic_number=56),
234 override_pending=False)
236 self.assertTrue(call.cancel())
237 self.assertFalse(call.cancel()) # Already cancelled, returns False
239 # Unary RPCs do not send a cancel request to the server.
240 self.assertIsNone(self._last_request)
242 callback.assert_not_called()
# Invoking the unary RPC with override_pending=True must send a packet of
# type REQUEST, observed through self._last_request.
# NOTE(review): a line directly after the def is not visible in this
# excerpt.
244 def test_reinvoke_unary_rpc(self):
246 self._last_request = None
247 self._service.SomeUnary.invoke(self._request(magic_number=55),
248 override_pending=True)
249 self.assertEqual(self._last_request.type,
250 packet_pb2.PacketType.REQUEST)
# Queues two responses and a stream end (ABORTED) for the server-streaming
# RPC, then collects the call's results with list() and inspects the
# magic_number of the sent request payload.
# NOTE(review): the heads of both assertions (and so the expected values)
# are not visible in this excerpt.
252 def test_invoke_server_streaming(self):
253 method = self._service.SomeServerStreaming.method
255 rep1 = method.response_type(payload='!!!')
256 rep2 = method.response_type(payload='?')
259 self._enqueue_response(1, method, response=rep1)
260 self._enqueue_response(1, method, response=rep2)
261 self._enqueue_stream_end(1, method, Status.ABORTED)
265 list(self._service.SomeServerStreaming(magic_number=4)))
269 self._sent_payload(method.request_type).magic_number)
# Invokes the server-streaming RPC passing the same mock for both callback
# arguments, so a single mock records every notification. Expected calls,
# in order: (rpc, '!!!' response), (rpc, '?' response), (rpc, ABORTED).
# NOTE(review): the end of the expected-call list and the head of the
# payload assertion are not visible in this excerpt.
271 def test_invoke_server_streaming_with_callbacks(self):
272 method = self._service.SomeServerStreaming.method
274 rep1 = method.response_type(payload='!!!')
275 rep2 = method.response_type(payload='?')
278 self._enqueue_response(1, method, response=rep1)
279 self._enqueue_response(1, method, response=rep2)
280 self._enqueue_stream_end(1, method, Status.ABORTED)
282 callback = mock.Mock()
283 self._service.SomeServerStreaming.invoke(
284 self._request(magic_number=3), callback, callback)
286 rpc = _rpc(self._service.SomeServerStreaming)
287 callback.assert_has_calls([
288 mock.call(rpc, method.response_type(payload='!!!')),
289 mock.call(rpc, method.response_type(payload='?')),
290 mock.call(rpc, Status.ABORTED),
295 self._sent_payload(method.request_type).magic_number)
# Starts a server-streaming call that receives one response, verifies the
# callback saw exactly that response, and then expects the last packet sent
# to be CANCEL_SERVER_STREAM. Afterwards the RPC is invoked again and must
# deliver a response followed by Status.OK to the callback.
# NOTE(review): the statement that triggers the cancel packet is not
# visible in this excerpt (presumably a call.cancel() — confirm).
297 def test_invoke_server_streaming_with_callback_cancel(self):
298 stub = self._service.SomeServerStreaming
300 resp = stub.method.response_type(payload='!!!')
301 self._enqueue_response(1, stub.method, response=resp)
303 callback = mock.Mock()
304 call = stub.invoke(self._request(magic_number=3), callback)
305 callback.assert_called_once_with(
306 _rpc(stub), stub.method.response_type(payload='!!!'))
# Forget the calls recorded so far so the later assertion only sees the
# second invocation's notifications.
308 callback.reset_mock()
312 self.assertEqual(self._last_request.type,
313 packet_pb2.PacketType.CANCEL_SERVER_STREAM)
315 # Ensure the RPC can be called after being cancelled.
316 self._enqueue_response(1, stub.method, response=resp)
317 self._enqueue_stream_end(1, stub.method, Status.OK)
319 call = stub.invoke(self._request(magic_number=3), callback, callback)
321 callback.assert_has_calls([
322 mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
323 mock.call(_rpc(stub), Status.OK),
# Queues a series of packets that do not match the pending unary call,
# each with the process_packet() status the test expects for it, followed
# by one matching response. The blocking call must then return Status.OK
# with an empty payload, i.e. the mismatched packets did not complete it.
# Mismatches queued: first argument 999 instead of 1 (presumably the
# channel id — confirm), service id 999, method id 999, and the id of
# SomeBidiStreaming, which has no pending call.
# NOTE(review): a few lines of this test are not visible in this excerpt.
326 def test_ignore_bad_packets_with_pending_rpc(self):
327 method = self._service.SomeUnary.method
328 service_id = method.service.id
331 self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
333 self._enqueue_response(1,
334 ids=(999, method.id),
335 process_status=Status.OK)
337 self._enqueue_response(1,
338 ids=(service_id, 999),
339 process_status=Status.OK)
340 # For RPC not pending (is Status.OK because the packet is processed)
341 self._enqueue_response(1,
343 self._service.SomeBidiStreaming.method.id),
344 process_status=Status.OK)
346 self._enqueue_response(1, method, process_status=Status.OK)
348 status, response = self._service.SomeUnary(magic_number=6)
349 self.assertIs(Status.OK, status)
350 self.assertEqual('', response.payload)
# Queues a response for the unary RPC and expects the blocking call to
# return Status.OK together with a response of None.
# NOTE(review): the arguments describing the queued payload are not visible
# in this excerpt; per the test name it is presumably an undecodable
# payload — confirm.
352 def test_pass_none_if_payload_fails_to_decode(self):
353 method = self._service.SomeUnary.method
355 self._enqueue_response(1,
359 process_status=Status.OK)
361 status, response = self._service.SomeUnary(magic_number=6)
362 self.assertIs(status, Status.OK)
363 self.assertIsNone(response)
def test_rpc_help_contains_method_name(self):
    """The help text of an RPC stub mentions the method's full name."""
    unary_stub = self._service.SomeUnary
    full_name = unary_stub.method.full_name
    help_text = unary_stub.help()
    self.assertIn(full_name, help_text)
def test_default_timeouts_set_on_impl(self):
    """Impl exposes the default timeouts it was constructed with.

    The two positional constructor arguments are expected to be readable
    back as default_unary_timeout_s and default_stream_timeout_s.
    """
    impl = callback_client.Impl(None, 1.5)

    # None is a singleton: check identity, so a value that merely compares
    # equal to None cannot satisfy the test.
    self.assertIsNone(impl.default_unary_timeout_s)
    self.assertEqual(impl.default_stream_timeout_s, 1.5)
# Builds a separate client whose Impl is constructed with (99, 100) and a
# channel output that ignores everything, then inspects default_timeout_s
# on the SomeUnary stub (compared with 99) and on the SomeServerStreaming
# stub.
# NOTE(review): the heads of both assertions and the value expected for the
# streaming stub are not visible in this excerpt.
375 def test_default_timeouts_set_for_all_rpcs(self):
376 rpc_client = client.Client.from_modules(callback_client.Impl(
377 99, 100), [client.Channel(1, lambda *a, **b: None)],
378 self._protos.modules())
379 rpcs = rpc_client.channel(1).rpcs
382 rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
384 rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
def test_timeout_unary(self):
    """A unary call given a tiny per-call timeout raises RpcTimeout.

    This test queues no response before making the call.
    """
    tiny_timeout_s = 0.0001
    self.assertRaises(callback_client.RpcTimeout,
                      self._service.SomeUnary,
                      pw_rpc_timeout_s=tiny_timeout_s)
def test_timeout_unary_set_default(self):
    """A tiny default timeout set on the stub applies to later calls.

    The call passes no explicit timeout and this test queues no response,
    so RpcTimeout is expected.
    """
    tiny_timeout_s = 0.0001
    self._service.SomeUnary.default_timeout_s = tiny_timeout_s

    self.assertRaises(callback_client.RpcTimeout, self._service.SomeUnary)
# Starts a server-streaming call with a tiny per-call timeout
# (pw_rpc_timeout_s=0.0001) and expects RpcTimeout inside the
# assertRaises block.
# NOTE(review): the body of the `with` block is not visible in this
# excerpt; per the test name it presumably iterates `responses` — confirm.
397 def test_timeout_server_streaming_iteration(self):
398 responses = self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
399 with self.assertRaises(callback_client.RpcTimeout):
# Starts a server-streaming call without a per-call timeout and expects
# RpcTimeout when iterating responses.responses(timeout_s=0.0001), i.e.
# when the timeout is supplied at iteration time instead.
# NOTE(review): the loop body is not visible in this excerpt.
403 def test_timeout_server_streaming_responses(self):
404 responses = self._service.SomeServerStreaming()
405 with self.assertRaises(callback_client.RpcTimeout):
406 for _ in responses.responses(timeout_s=0.0001):
410 if __name__ == '__main__':