Fix for x86_64 build fail
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_rpc / py / pw_rpc / callback_client.py
1 # Copyright 2021 The Pigweed Authors
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
5 # the License at
6 #
7 #     https://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
13 # the License.
14 """Defines a callback-based RPC ClientImpl to use with pw_rpc.Client.
15
16 callback_client.Impl supports invoking RPCs synchronously or asynchronously.
17 Asynchronous invocations use a callback.
18
19 Synchronous invocations look like a function call:
20
21   status, response = client.channel(1).call.MyServer.MyUnary(some_field=123)
22
23   # Streaming calls return an iterable of responses
24   for reply in client.channel(1).call.MyService.MyServerStreaming(request):
25       pass
26
27 Asynchronous invocations pass a callback in addition to the request. The
28 callback must be a callable that accepts a status and a payload, either of
29 which may be None. The Status is only set when the RPC is completed.
30
31   callback = lambda status, payload: print('Response:', status, payload)
32
33   call = client.channel(1).call.MyServer.MyUnary.invoke(
34       callback, some_field=123)
35
36   call = client.channel(1).call.MyService.MyServerStreaming.invoke(
37       callback, request):
38
39 When invoking a method, requests may be provided as a message object or as
40 kwargs for the message fields (but not both).
41 """
42
43 import enum
44 import inspect
45 import logging
46 import queue
47 import textwrap
48 import threading
49 from typing import Any, Callable, Iterator, NamedTuple, Union, Optional
50
51 from pw_protobuf_compiler.python_protos import proto_repr
52 from pw_status import Status
53
54 from pw_rpc import client, descriptors
55 from pw_rpc.client import PendingRpc, PendingRpcs
56 from pw_rpc.descriptors import Channel, Method, Service
57
58 _LOG = logging.getLogger(__name__)
59
60
61 class UseDefault(enum.Enum):
62     """Marker for args that should use a default value, when None is valid."""
63     VALUE = 0
64
65
66 OptionalTimeout = Union[UseDefault, float, None]
67
68 ResponseCallback = Callable[[PendingRpc, Any], Any]
69 CompletionCallback = Callable[[PendingRpc, Any], Any]
70 ErrorCallback = Callable[[PendingRpc, Status], Any]
71
72
73 class _Callbacks(NamedTuple):
74     response: ResponseCallback
75     completion: CompletionCallback
76     error: ErrorCallback
77
78
79 def _default_response(rpc: PendingRpc, response: Any) -> None:
80     _LOG.info('%s response: %s', rpc, response)
81
82
83 def _default_completion(rpc: PendingRpc, status: Status) -> None:
84     _LOG.info('%s finished: %s', rpc, status)
85
86
87 def _default_error(rpc: PendingRpc, status: Status) -> None:
88     _LOG.error('%s error: %s', rpc, status)
89
90
91 class _MethodClient:
92     """A method that can be invoked for a particular channel."""
93     def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs,
94                  channel: Channel, method: Method,
95                  default_timeout_s: Optional[float]):
96         self._impl = client_impl
97         self._rpcs = rpcs
98         self._rpc = PendingRpc(channel, method.service, method)
99         self.default_timeout_s: Optional[float] = default_timeout_s
100
101     @property
102     def channel(self) -> Channel:
103         return self._rpc.channel
104
105     @property
106     def method(self) -> Method:
107         return self._rpc.method
108
109     @property
110     def service(self) -> Service:
111         return self._rpc.service
112
113     def invoke(self,
114                request: Any,
115                response: ResponseCallback = _default_response,
116                completion: CompletionCallback = _default_completion,
117                error: ErrorCallback = _default_error,
118                override_pending: bool = True) -> '_AsyncCall':
119         """Invokes an RPC with callbacks."""
120         self._rpcs.send_request(self._rpc,
121                                 request,
122                                 _Callbacks(response, completion, error),
123                                 override_pending=override_pending)
124         return _AsyncCall(self._rpcs, self._rpc)
125
126     def __repr__(self) -> str:
127         return self.help()
128
129     def __call__(self):
130         raise NotImplementedError('Implemented by derived classes')
131
132     def help(self) -> str:
133         """Returns a help message about this RPC."""
134         function_call = self.method.full_name + '('
135
136         docstring = inspect.getdoc(self.__call__)
137         assert docstring is not None
138
139         annotation = inspect.Signature.from_callable(self).return_annotation
140         if isinstance(annotation, type):
141             annotation = annotation.__name__
142
143         arg_sep = f',\n{" " * len(function_call)}'
144         return (
145             f'{function_call}'
146             f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
147             f'\n\n{textwrap.indent(docstring, "  ")}\n\n'
148             f'  Returns {annotation}.')
149
150
151 class RpcTimeout(Exception):
152     def __init__(self, rpc: PendingRpc, timeout: Optional[float]):
153         super().__init__(
154             f'No response received for {rpc.method} after {timeout} s')
155         self.rpc = rpc
156         self.timeout = timeout
157
158
159 class RpcError(Exception):
160     def __init__(self, rpc: PendingRpc, status: Status):
161         if status is Status.NOT_FOUND:
162             msg = ': the RPC server does not support this RPC'
163         else:
164             msg = ''
165
166         super().__init__(f'{rpc.method} failed with error {status}{msg}')
167         self.rpc = rpc
168         self.status = status
169
170
171 class _AsyncCall:
172     """Represents an ongoing callback-based call."""
173
174     # TODO(hepler): Consider alternatives (futures) and/or expand functionality.
175
176     def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc):
177         self._rpc = rpc
178         self._rpcs = rpcs
179
180     def cancel(self) -> bool:
181         return self._rpcs.send_cancel(self._rpc)
182
183     def __enter__(self) -> '_AsyncCall':
184         return self
185
186     def __exit__(self, exc_type, exc_value, traceback) -> None:
187         self.cancel()
188
189
190 class StreamingResponses:
191     """Used to iterate over a queue.SimpleQueue."""
192     def __init__(self, method_client: _MethodClient,
193                  responses: queue.SimpleQueue,
194                  default_timeout_s: OptionalTimeout):
195         self._method_client = method_client
196         self._queue = responses
197         self.status: Optional[Status] = None
198
199         if default_timeout_s is UseDefault.VALUE:
200             self.default_timeout_s = self._method_client.default_timeout_s
201         else:
202             self.default_timeout_s = default_timeout_s
203
204     @property
205     def method(self) -> Method:
206         return self._method_client.method
207
208     def responses(self,
209                   *,
210                   block: bool = True,
211                   timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
212         """Returns an iterator of stream responses.
213
214         Args:
215           timeout_s: timeout in seconds; None blocks indefinitely
216         """
217         if timeout_s is UseDefault.VALUE:
218             timeout_s = self.default_timeout_s
219
220         try:
221             while True:
222                 response = self._queue.get(block, timeout_s)
223
224                 if isinstance(response, Exception):
225                     raise response
226
227                 if isinstance(response, Status):
228                     self.status = response
229                     return
230
231                 yield response
232         except queue.Empty:
233             pass
234
235         raise RpcTimeout(self._method_client._rpc, timeout_s)  # pylint: disable=protected-access
236
237     def __iter__(self):
238         return self.responses()
239
240     def __repr__(self) -> str:
241         return f'{type(self).__name__}({self.method})'
242
243
244 def _method_client_docstring(method: Method) -> str:
245     return f'''\
246 Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.
247
248 Calling this directly invokes the RPC synchronously. The RPC can be invoked
249 asynchronously using the invoke method.
250 '''
251
252
253 def _function_docstring(method: Method) -> str:
254     return f'''\
255 Invokes the {method.full_name} {method.type.sentence_name()} RPC.
256
257 This function accepts either the request protobuf fields as keyword arguments or
258 a request protobuf as a positional argument.
259 '''
260
261
262 def _update_function_signature(method: Method, function: Callable) -> None:
263     """Updates the name, docstring, and parameters to match a method."""
264     function.__name__ = method.full_name
265     function.__doc__ = _function_docstring(method)
266
267     # In order to have good tab completion and help messages, update the
268     # function signature to accept only keyword arguments for the proto message
269     # fields. This doesn't actually change the function signature -- it just
270     # updates how it appears when inspected.
271     sig = inspect.signature(function)
272
273     params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
274     params += method.request_parameters()
275     params.append(
276         inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY))
277     function.__signature__ = sig.replace(  # type: ignore[attr-defined]
278         parameters=params)
279
280
281 class UnaryResponse(NamedTuple):
282     """Result of invoking a unary RPC: status and response."""
283     status: Status
284     response: Any
285
286     def __repr__(self) -> str:
287         return f'({self.status}, {proto_repr(self.response)})'
288
289
290 class _UnaryResponseHandler:
291     """Tracks the state of an ongoing synchronous unary RPC call."""
292     def __init__(self, rpc: PendingRpc):
293         self._rpc = rpc
294         self._response: Any = None
295         self._status: Optional[Status] = None
296         self._error: Optional[RpcError] = None
297         self._event = threading.Event()
298
299     def on_response(self, _: PendingRpc, response: Any) -> None:
300         self._response = response
301
302     def on_completion(self, _: PendingRpc, status: Status) -> None:
303         self._status = status
304         self._event.set()
305
306     def on_error(self, _: PendingRpc, status: Status) -> None:
307         self._error = RpcError(self._rpc, status)
308         self._event.set()
309
310     def wait(self, timeout_s: Optional[float]) -> UnaryResponse:
311         if not self._event.wait(timeout_s):
312             raise RpcTimeout(self._rpc, timeout_s)
313
314         if self._error is not None:
315             raise self._error
316
317         assert self._status is not None
318         return UnaryResponse(self._status, self._response)
319
320
321 def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
322                          channel: Channel, method: Method,
323                          default_timeout: Optional[float]) -> _MethodClient:
324     """Creates an object used to call a unary method."""
325     def call(self: _MethodClient,
326              _rpc_request_proto=None,
327              *,
328              pw_rpc_timeout_s=UseDefault.VALUE,
329              **request_fields) -> UnaryResponse:
330
331         handler = _UnaryResponseHandler(self._rpc)  # pylint: disable=protected-access
332         self.invoke(
333             self.method.get_request(_rpc_request_proto, request_fields),
334             handler.on_response, handler.on_completion, handler.on_error)
335
336         if pw_rpc_timeout_s is UseDefault.VALUE:
337             pw_rpc_timeout_s = self.default_timeout_s
338
339         return handler.wait(pw_rpc_timeout_s)
340
341     _update_function_signature(method, call)
342
343     # The MethodClient class is created dynamically so that the __call__ method
344     # can be configured differently for each method.
345     method_client_type = type(
346         f'{method.name}_UnaryMethodClient', (_MethodClient, ),
347         dict(__call__=call, __doc__=_method_client_docstring(method)))
348     return method_client_type(client_impl, rpcs, channel, method,
349                               default_timeout)
350
351
352 def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
353                                     channel: Channel, method: Method,
354                                     default_timeout: Optional[float]):
355     """Creates an object used to call a server streaming method."""
356     def call(self: _MethodClient,
357              _rpc_request_proto=None,
358              *,
359              pw_rpc_timeout_s=UseDefault.VALUE,
360              **request_fields) -> StreamingResponses:
361         responses: queue.SimpleQueue = queue.SimpleQueue()
362         self.invoke(
363             self.method.get_request(_rpc_request_proto, request_fields),
364             lambda _, response: responses.put(response),
365             lambda _, status: responses.put(status),
366             lambda rpc, status: responses.put(RpcError(rpc, status)))
367         return StreamingResponses(self, responses, pw_rpc_timeout_s)
368
369     _update_function_signature(method, call)
370
371     # The MethodClient class is created dynamically so that the __call__ method
372     # can be configured differently for each method type.
373     method_client_type = type(
374         f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ),
375         dict(__call__=call, __doc__=_method_client_docstring(method)))
376     return method_client_type(client_impl, rpcs, channel, method,
377                               default_timeout)
378
379
380 class ClientStreamingMethodClient(_MethodClient):
381     def __call__(self):
382         raise NotImplementedError
383
384     def invoke(self,
385                request: Any,
386                response: ResponseCallback = _default_response,
387                completion: CompletionCallback = _default_completion,
388                error: ErrorCallback = _default_error,
389                override_pending: bool = True) -> _AsyncCall:
390         raise NotImplementedError
391
392
393 class BidirectionalStreamingMethodClient(_MethodClient):
394     def __call__(self):
395         raise NotImplementedError
396
397     def invoke(self,
398                request: Any,
399                response: ResponseCallback = _default_response,
400                completion: CompletionCallback = _default_completion,
401                error: ErrorCallback = _default_error,
402                override_pending: bool = True) -> _AsyncCall:
403         raise NotImplementedError
404
405
406 class Impl(client.ClientImpl):
407     """Callback-based ClientImpl."""
408     def __init__(self,
409                  default_unary_timeout_s: Optional[float] = 1.0,
410                  default_stream_timeout_s: Optional[float] = 1.0):
411         super().__init__()
412         self._default_unary_timeout_s = default_unary_timeout_s
413         self._default_stream_timeout_s = default_stream_timeout_s
414
415     @property
416     def default_unary_timeout_s(self) -> Optional[float]:
417         return self._default_unary_timeout_s
418
419     @property
420     def default_stream_timeout_s(self) -> Optional[float]:
421         return self._default_stream_timeout_s
422
423     def method_client(self, channel: Channel, method: Method) -> _MethodClient:
424         """Returns an object that invokes a method using the given chanel."""
425
426         if method.type is Method.Type.UNARY:
427             return _unary_method_client(self, self.rpcs, channel, method,
428                                         self.default_unary_timeout_s)
429
430         if method.type is Method.Type.SERVER_STREAMING:
431             return _server_streaming_method_client(
432                 self, self.rpcs, channel, method,
433                 self.default_stream_timeout_s)
434
435         if method.type is Method.Type.CLIENT_STREAMING:
436             return ClientStreamingMethodClient(self, self.rpcs, channel,
437                                                method,
438                                                self.default_unary_timeout_s)
439
440         if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
441             return BidirectionalStreamingMethodClient(
442                 self, self.rpcs, channel, method,
443                 self.default_stream_timeout_s)
444
445         raise AssertionError(f'Unknown method type {method.type}')
446
447     def handle_response(self,
448                         rpc: PendingRpc,
449                         context,
450                         payload,
451                         *,
452                         args: tuple = (),
453                         kwargs: dict = None) -> None:
454         """Invokes the callback associated with this RPC.
455
456         Any additional positional and keyword args passed through
457         Client.process_packet are forwarded to the callback.
458         """
459         if kwargs is None:
460             kwargs = {}
461
462         try:
463             context.response(rpc, payload, *args, **kwargs)
464         except:  # pylint: disable=bare-except
465             self.rpcs.send_cancel(rpc)
466             _LOG.exception('Response callback %s for %s raised exception',
467                            context.response, rpc)
468
469     def handle_completion(self,
470                           rpc: PendingRpc,
471                           context,
472                           status: Status,
473                           *,
474                           args: tuple = (),
475                           kwargs: dict = None):
476         if kwargs is None:
477             kwargs = {}
478
479         try:
480             context.completion(rpc, status, *args, **kwargs)
481         except:  # pylint: disable=bare-except
482             _LOG.exception('Completion callback %s for %s raised exception',
483                            context.completion, rpc)
484
485     def handle_error(self,
486                      rpc: PendingRpc,
487                      context,
488                      status: Status,
489                      *,
490                      args: tuple = (),
491                      kwargs: dict = None) -> None:
492         if kwargs is None:
493             kwargs = {}
494
495         try:
496             context.error(rpc, status, *args, **kwargs)
497         except:  # pylint: disable=bare-except
498             _LOG.exception('Error callback %s for %s raised exception',
499                            context.error, rpc)