Fix for x86_64 build fail
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_rpc / py / callback_client_test.py
1 #!/usr/bin/env python3
2 # Copyright 2021 The Pigweed Authors
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
5 # use this file except in compliance with the License. You may obtain a copy of
6 # the License at
7 #
8 #     https://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 # License for the specific language governing permissions and limitations under
14 # the License.
15 """Tests using the callback client for pw_rpc."""
16
import unittest
from unittest import mock
from typing import List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_rpc_protos.internal import packet_pb2
from pw_status import Status

from pw_rpc import callback_client, client, packets
26
# Proto source compiled in CallbackClientImplTest.setUp() with
# python_protos.Library.from_strings. PublicService declares one method of each
# RPC kind: unary, server streaming, client streaming, and bidirectional
# streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
54
55
def _rpc(method_stub):
    """Returns the client.PendingRpc for calls made through method_stub.

    The tests compare this value against the first argument that the
    callback client passes to user callbacks.
    """
    rpc_method = method_stub.method
    channel = method_stub.channel
    return client.PendingRpc(channel, rpc_method.service, rpc_method)
59
60
class CallbackClientImplTest(unittest.TestCase):
    """Tests the callback_client as used within a pw_rpc Client."""
    def setUp(self):
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        # Channel 1 sends every outgoing packet to _handle_request, which
        # plays back the packets queued by the _enqueue_* helpers.
        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_request)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Most recent packet sent by the client (decoded); None until a send.
        self._last_request: Optional[packet_pb2.RpcPacket] = None
        # Pairs of (encoded packet, status expected from process_packet).
        self._next_packets: List[Tuple[bytes, Status]] = []
        self._send_responses_on_request = True

    def _enqueue_packet(self, packet: packet_pb2.RpcPacket,
                        process_status: Status) -> None:
        """Queues a packet to be processed on the client's next send.

        Args:
          packet: packet to hand to Client.process_packet.
          process_status: status process_packet is expected to return.
        """
        self._next_packets.append(
            (packet.SerializeToString(), process_status))

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          response=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK):
        """Queues a RESPONSE packet for the client's next send.

        Exactly one of method or ids must be provided to address the packet.

        Args:
          channel_id: channel ID placed in the packet.
          method: method whose service and method IDs address the packet.
          status: status placed in the packet.
          response: payload, as raw bytes or a message to serialize.
          ids: explicit (service_id, method_id), e.g. for unknown RPCs.
          process_status: status process_packet is expected to return.
        """
        if method is not None:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None
            service_id, method_id = ids

        if isinstance(response, bytes):
            payload = response
        else:
            payload = response.SerializeToString()

        self._enqueue_packet(
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.RESPONSE,
                                 channel_id=channel_id,
                                 service_id=service_id,
                                 method_id=method_id,
                                 status=status.value,
                                 payload=payload), process_status)

    def _enqueue_stream_end(self,
                            channel_id: int,
                            method,
                            status: Status = Status.OK,
                            process_status=Status.OK):
        """Queues a SERVER_STREAM_END packet for method with status."""
        self._enqueue_packet(
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_STREAM_END,
                                 channel_id=channel_id,
                                 service_id=method.service.id,
                                 method_id=method.id,
                                 status=status.value), process_status)

    def _enqueue_error(self,
                       channel_id: int,
                       method,
                       status: Status,
                       process_status=Status.OK):
        """Queues a SERVER_ERROR packet for method with status."""
        self._enqueue_packet(
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
                                 channel_id=channel_id,
                                 service_id=method.service.id,
                                 method_id=method.id,
                                 status=status.value), process_status)

    def _handle_request(self, data: bytes):
        """Channel output: records the sent packet and plays back responses.

        Every queued packet is passed to Client.process_packet, and its
        return value is checked against the expected status.
        """
        # Disable this method to prevent infinite recursion if processing the
        # packet happens to send another packet.
        if not self._send_responses_on_request:
            return

        self._send_responses_on_request = False
        try:
            self._last_request = packets.decode(data)

            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Always restore state. Otherwise a failure here would leave the
            # handler disabled and stale packets queued for the next request.
            self._next_packets.clear()
            self._send_responses_on_request = True

    def _sent_payload(self, message_type):
        """Parses the payload of the last sent packet as message_type."""
        self.assertIsNotNone(self._last_request)
        message = message_type()
        message.ParseFromString(self._last_request.payload)
        return message

    def test_invoke_unary_rpc(self):
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                method.request_type(magic_number=6))

            self.assertEqual(
                6,
                self._sent_payload(method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_invoke_unary_rpc_with_callback(self):
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            callback = mock.Mock()
            self._service.SomeUnary.invoke(self._request(magic_number=5),
                                           callback, callback)

            callback.assert_has_calls([
                mock.call(_rpc(self._service.SomeUnary),
                          method.response_type(payload='0_o')),
                mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
            ])

            self.assertEqual(
                5,
                self._sent_payload(method.request_type).magic_number)

    def test_unary_rpc_server_error(self):
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_error(1, method, Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_invoke_unary_rpc_callback_exceptions_suppressed(self):
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_invoke_unary_rpc_with_callback_cancel(self):
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertIsNotNone(self._last_request)
            self._last_request = None

            # Try to invoke the RPC again before cancelling, without overriding
            # pending RPCs.
            with self.assertRaises(client.Error):
                self._service.SomeUnary.invoke(self._request(magic_number=56),
                                               callback,
                                               override_pending=False)

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertIsNone(self._last_request)

        callback.assert_not_called()

    def test_reinvoke_unary_rpc(self):
        for _ in range(3):
            self._last_request = None
            self._service.SomeUnary.invoke(self._request(magic_number=55),
                                           override_pending=True)
            self.assertEqual(self._last_request.type,
                             packet_pb2.PacketType.REQUEST)

    def test_invoke_server_streaming(self):
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                list(self._service.SomeServerStreaming(magic_number=4)))

            self.assertEqual(
                4,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callbacks(self):
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            callback = mock.Mock()
            self._service.SomeServerStreaming.invoke(
                self._request(magic_number=3), callback, callback)

            rpc = _rpc(self._service.SomeServerStreaming)
            callback.assert_has_calls([
                mock.call(rpc, method.response_type(payload='!!!')),
                mock.call(rpc, method.response_type(payload='?')),
                mock.call(rpc, Status.ABORTED),
            ])

            self.assertEqual(
                3,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callback_cancel(self):
        stub = self._service.SomeServerStreaming

        resp = stub.method.response_type(payload='!!!')
        self._enqueue_response(1, stub.method, response=resp)

        callback = mock.Mock()
        call = stub.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            _rpc(stub), stub.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self._last_request.type,
                         packet_pb2.PacketType.CANCEL_SERVER_STREAM)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_response(1, stub.method, response=resp)
        self._enqueue_stream_end(1, stub.method, Status.OK)

        stub.invoke(self._request(magic_number=3), callback, callback)

        callback.assert_has_calls([
            mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
            mock.call(_rpc(stub), Status.OK),
        ])

    def test_ignore_bad_packets_with_pending_rpc(self):
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_pass_none_if_payload_fails_to_decode(self):
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(status, Status.OK)
        self.assertIsNone(response)

    def test_rpc_help_contains_method_name(self):
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self):
        impl = callback_client.Impl(None, 1.5)

        self.assertEqual(impl.default_unary_timeout_s, None)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self):
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)],
            self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)

    def test_timeout_unary(self):
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_timeout_unary_set_default(self):
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_timeout_server_streaming_iteration(self):
        responses = self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses:
                pass

    def test_timeout_server_streaming_responses(self):
        responses = self._service.SomeServerStreaming()
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses.responses(timeout_s=0.0001):
                pass
408
409
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()