1 # Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
14 """Defines a callback-based RPC ClientImpl to use with pw_rpc.Client.
16 callback_client.Impl supports invoking RPCs synchronously or asynchronously.
17 Asynchronous invocations use a callback.
19 Synchronous invocations look like a function call:

  status, response = client.channel(1).call.MyService.MyUnary(some_field=123)

  # Streaming calls return an iterable of responses
  for reply in client.channel(1).call.MyService.MyServerStreaming(request):
      pass

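Asynchronous invocations pass the request to the method's invoke() along with
optional response, completion, and error callbacks. Each callback receives the
PendingRpc as its first argument; the response callback also receives a
response payload, while the completion and error callbacks receive a Status.
The completion callback is only called when the RPC has finished.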

  on_response = lambda rpc, payload: print('Response:', rpc, payload)

  call = client.channel(1).call.MyService.MyUnary.invoke(
      request, response=on_response)

  call = client.channel(1).call.MyService.MyServerStreaming.invoke(
      request, response=on_response)

When calling a method synchronously, the request may be provided as a message
object or as keyword arguments for the message fields (but not both).
"""
import enum
import inspect
import logging
import queue
import textwrap
import threading
from typing import Any, Callable, Iterator, NamedTuple, Union, Optional

from pw_protobuf_compiler.python_protos import proto_repr
from pw_status import Status

from pw_rpc import client, descriptors
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Channel, Method, Service
# Module logger; the default callbacks and the Impl callback handlers report
# RPC events and callback failures through it.
_LOG = logging.getLogger(name=__name__)
class UseDefault(enum.Enum):
    """Marker for args that should use a default value, when None is valid.

    Timeout arguments accept None to mean "block indefinitely", so None cannot
    also mean "not provided". UseDefault.VALUE fills that role and is replaced
    with the relevant default timeout before the timeout is used.
    """
    VALUE = 0
# A timeout in seconds, None to block indefinitely, or UseDefault.VALUE to fall
# back to the configured default timeout.
OptionalTimeout = Union[UseDefault, float, None]

# Called with the pending RPC and a response payload.
ResponseCallback = Callable[[PendingRpc, Any], Any]
# Called with the pending RPC and its final Status when the RPC completes.
CompletionCallback = Callable[[PendingRpc, Status], Any]
# Called with the pending RPC and the error Status when the RPC fails.
ErrorCallback = Callable[[PendingRpc, Status], Any]
class _Callbacks(NamedTuple):
    """Callbacks registered with send_request for one pending RPC.

    Impl.handle_response/handle_completion/handle_error receive this tuple as
    their context argument and call the matching field.
    """
    response: ResponseCallback
    completion: CompletionCallback
    error: ErrorCallback
def _default_response(rpc: PendingRpc, response: Any) -> None:
    """Default response callback: logs the response at INFO level."""
    _LOG.info('%s response: %s', rpc, response)
def _default_completion(rpc: PendingRpc, status: Status) -> None:
    """Default completion callback: logs the final status at INFO level."""
    _LOG.info('%s finished: %s', rpc, status)
def _default_error(rpc: PendingRpc, status: Status) -> None:
    """Default error callback: logs the error status at ERROR level."""
    _LOG.error('%s error: %s', rpc, status)
class _MethodClient:
    """A method that can be invoked for a particular channel.

    Subclasses (created per method) implement __call__ for synchronous
    invocation; invoke() starts the RPC asynchronously with callbacks.
    """
    def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs,
                 channel: Channel, method: Method,
                 default_timeout_s: Optional[float]):
        self._impl = client_impl
        self._rpcs = rpcs
        self._rpc = PendingRpc(channel, method.service, method)
        self.default_timeout_s: Optional[float] = default_timeout_s

    @property
    def channel(self) -> Channel:
        return self._rpc.channel

    @property
    def method(self) -> Method:
        return self._rpc.method

    @property
    def service(self) -> Service:
        return self._rpc.service

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               override_pending: bool = True) -> '_AsyncCall':
        """Invokes an RPC with callbacks.

        Args:
          request: the request message to send
          response: called with (rpc, payload) for each response
          completion: called with (rpc, status) when the RPC completes
          error: called with (rpc, status) if the RPC fails
          override_pending: forwarded to PendingRpcs.send_request

        Returns:
          an _AsyncCall that can be used to cancel the RPC
        """
        self._rpcs.send_request(self._rpc,
                                request,
                                _Callbacks(response, completion, error),
                                override_pending=override_pending)
        return _AsyncCall(self._rpcs, self._rpc)

    def __repr__(self) -> str:
        return self.help()

    def __call__(self, *args, **kwargs):
        raise NotImplementedError('Implemented by derived classes')

    def help(self) -> str:
        """Returns a help message about this RPC."""
        function_call = self.method.full_name + '('

        # Not every subclass documents __call__; fall back to an empty
        # description rather than failing (this also backs __repr__).
        docstring = inspect.getdoc(self.__call__) or ''

        annotation = inspect.Signature.from_callable(self).return_annotation
        if isinstance(annotation, type):
            annotation = annotation.__name__

        # Align each request field under the first, just after the '('.
        arg_sep = f',\n{" " * len(function_call)}'
        return (
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
            f'\n\n{textwrap.indent(docstring, " ")}\n\n'
            f' Returns {annotation}.')
class RpcTimeout(Exception):
    """Raised when an RPC produces no result within its timeout."""
    def __init__(self, rpc: PendingRpc, timeout: Optional[float]):
        super().__init__(
            f'No response received for {rpc.method} after {timeout} s')
        self.rpc = rpc
        self.timeout = timeout
class RpcError(Exception):
    """Raised when an RPC fails with an error status."""
    def __init__(self, rpc: PendingRpc, status: Status):
        # NOT_FOUND gets an explanation; other statuses are reported as-is.
        if status is Status.NOT_FOUND:
            msg = ': the RPC server does not support this RPC'
        else:
            msg = ''

        super().__init__(f'{rpc.method} failed with error {status}{msg}')
        self.rpc = rpc
        self.status = status
class _AsyncCall:
    """Represents an ongoing callback-based call."""

    # TODO(hepler): Consider alternatives (futures) and/or expand functionality.

    def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc):
        self._rpc = rpc
        self._rpcs = rpcs

    def cancel(self) -> bool:
        """Cancels the RPC; returns the result of PendingRpcs.send_cancel."""
        return self._rpcs.send_cancel(self._rpc)

    def __enter__(self) -> '_AsyncCall':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # NOTE(review): leaving the with-block is assumed to end the call;
        # confirm callers expect cancellation here.
        self.cancel()
class StreamingResponses:
    """Used to iterate over a queue.SimpleQueue.

    The queue holds response payloads, followed by either the final Status
    (stream completed) or an RpcError (stream failed).
    """
    def __init__(self, method_client: _MethodClient,
                 responses: queue.SimpleQueue,
                 default_timeout_s: OptionalTimeout):
        self._method_client = method_client
        self._queue = responses
        self.status: Optional[Status] = None

        # UseDefault.VALUE defers to the method client's default; an explicit
        # value (including None, meaning block indefinitely) is kept as-is.
        if default_timeout_s is UseDefault.VALUE:
            self.default_timeout_s = self._method_client.default_timeout_s
        else:
            self.default_timeout_s = default_timeout_s

    @property
    def method(self) -> Method:
        return self._method_client.method

    def responses(self,
                  block: bool = True,
                  timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          block: passed to SimpleQueue.get; if False, an empty queue raises
              RpcTimeout immediately
          timeout_s: timeout in seconds; None blocks indefinitely

        Raises:
          RpcError: the RPC failed
          RpcTimeout: nothing arrived on the queue within timeout_s
        """
        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        try:
            while True:
                response = self._queue.get(block, timeout_s)

                # The error callback enqueues an RpcError.
                if isinstance(response, Exception):
                    raise response

                # The completion callback enqueues the final Status.
                if isinstance(response, Status):
                    self.status = response
                    return

                yield response
        except queue.Empty:
            rpc = self._method_client._rpc  # pylint: disable=protected-access
            raise RpcTimeout(rpc, timeout_s) from None

    def __iter__(self):
        return self.responses()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'
def _method_client_docstring(method: Method) -> str:
    """Returns the docstring for a dynamically created method client class."""
    return f'''\
Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.

Calling this directly invokes the RPC synchronously. The RPC can be invoked
asynchronously using the invoke method.
'''
def _function_docstring(method: Method) -> str:
    """Returns the docstring for a method client's synchronous __call__."""
    return f'''\
Invokes the {method.full_name} {method.type.sentence_name()} RPC.

This function accepts either the request protobuf fields as keyword arguments or
a request protobuf as a positional argument.
'''
def _update_function_signature(method: Method, function: Callable) -> None:
    """Updates the name, docstring, and parameters to match a method."""
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)

    # In order to have good tab completion and help messages, update the
    # function signature to accept only keyword arguments for the proto message
    # fields. This doesn't actually change the function signature -- it just
    # updates how it appears when inspected.
    sig = inspect.signature(function)

    params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
    params += method.request_parameters()
    params.append(
        inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY))
    function.__signature__ = sig.replace(  # type: ignore[attr-defined]
        parameters=params)
class UnaryResponse(NamedTuple):
    """Result of invoking a unary RPC: status and response."""
    status: Status
    response: Any

    def __repr__(self) -> str:
        return f'({self.status}, {proto_repr(self.response)})'
class _UnaryResponseHandler:
    """Tracks the state of an ongoing synchronous unary RPC call."""
    def __init__(self, rpc: PendingRpc):
        self._rpc = rpc
        self._response: Any = None
        self._status: Optional[Status] = None
        self._error: Optional[RpcError] = None
        # Set by on_completion or on_error to release wait().
        self._event = threading.Event()

    def on_response(self, _: PendingRpc, response: Any) -> None:
        self._response = response

    def on_completion(self, _: PendingRpc, status: Status) -> None:
        self._status = status
        self._event.set()

    def on_error(self, _: PendingRpc, status: Status) -> None:
        self._error = RpcError(self._rpc, status)
        self._event.set()

    def wait(self, timeout_s: Optional[float]) -> UnaryResponse:
        """Blocks until the RPC completes or fails and returns its result.

        Args:
          timeout_s: seconds to wait; None blocks indefinitely

        Raises:
          RpcTimeout: the RPC did not finish within timeout_s
          RpcError: the RPC failed
        """
        if not self._event.wait(timeout_s):
            raise RpcTimeout(self._rpc, timeout_s)

        if self._error is not None:
            raise self._error

        assert self._status is not None
        return UnaryResponse(self._status, self._response)
def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                         channel: Channel, method: Method,
                         default_timeout: Optional[float]) -> _MethodClient:
    """Creates an object used to call a unary method."""
    def call(self: _MethodClient,
             _rpc_request_proto=None,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> UnaryResponse:
        # The handler's callbacks record the outcome; wait() blocks until the
        # RPC completes or fails.
        handler = _UnaryResponseHandler(self._rpc)  # pylint: disable=protected-access
        self.invoke(
            self.method.get_request(_rpc_request_proto, request_fields),
            handler.on_response, handler.on_completion, handler.on_error)

        if pw_rpc_timeout_s is UseDefault.VALUE:
            pw_rpc_timeout_s = self.default_timeout_s

        return handler.wait(pw_rpc_timeout_s)

    _update_function_signature(method, call)

    # The MethodClient class is created dynamically so that the __call__ method
    # can be configured differently for each method.
    method_client_type = type(
        f'{method.name}_UnaryMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return method_client_type(client_impl, rpcs, channel, method,
                              default_timeout)
def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                                    channel: Channel, method: Method,
                                    default_timeout: Optional[float]
                                    ) -> _MethodClient:
    """Creates an object used to call a server streaming method."""
    def call(self: _MethodClient,
             _rpc_request_proto=None,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> StreamingResponses:
        # Responses, the final Status, and any RpcError all go into one queue,
        # which StreamingResponses consumes in arrival order.
        responses: queue.SimpleQueue = queue.SimpleQueue()
        self.invoke(
            self.method.get_request(_rpc_request_proto, request_fields),
            lambda _, response: responses.put(response),
            lambda _, status: responses.put(status),
            lambda rpc, status: responses.put(RpcError(rpc, status)))
        return StreamingResponses(self, responses, pw_rpc_timeout_s)

    _update_function_signature(method, call)

    # The MethodClient class is created dynamically so that the __call__ method
    # can be configured differently for each method type.
    method_client_type = type(
        f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return method_client_type(client_impl, rpcs, channel, method,
                              default_timeout)
class ClientStreamingMethodClient(_MethodClient):
    """Client for client streaming RPCs; invocation is not implemented yet."""
    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               override_pending: bool = True) -> _AsyncCall:
        raise NotImplementedError
class BidirectionalStreamingMethodClient(_MethodClient):
    """Client for bidirectional streaming RPCs; not implemented yet."""
    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               override_pending: bool = True) -> _AsyncCall:
        raise NotImplementedError
406 class Impl(client.ClientImpl):
407 """Callback-based ClientImpl."""
def __init__(self,
             default_unary_timeout_s: Optional[float] = 1.0,
             default_stream_timeout_s: Optional[float] = 1.0):
    """Creates a callback-based client implementation.

    Args:
      default_unary_timeout_s: default timeout in seconds for unary and client
          streaming method clients; None blocks indefinitely
      default_stream_timeout_s: default timeout in seconds for server and
          bidirectional streaming method clients; None blocks indefinitely
    """
    super().__init__()
    self._default_unary_timeout_s = default_unary_timeout_s
    self._default_stream_timeout_s = default_stream_timeout_s
@property
def default_unary_timeout_s(self) -> Optional[float]:
    """Default timeout given to unary and client streaming method clients."""
    return self._default_unary_timeout_s
@property
def default_stream_timeout_s(self) -> Optional[float]:
    """Default timeout given to server and bidirectional streaming clients."""
    return self._default_stream_timeout_s
def method_client(self, channel: Channel, method: Method) -> _MethodClient:
    """Returns an object that invokes a method using the given channel.

    Raises:
      AssertionError: the method type is not recognized
    """
    if method.type is Method.Type.UNARY:
        return _unary_method_client(self, self.rpcs, channel, method,
                                    self.default_unary_timeout_s)

    if method.type is Method.Type.SERVER_STREAMING:
        return _server_streaming_method_client(
            self, self.rpcs, channel, method,
            self.default_stream_timeout_s)

    if method.type is Method.Type.CLIENT_STREAMING:
        return ClientStreamingMethodClient(self, self.rpcs, channel, method,
                                           self.default_unary_timeout_s)

    if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
        return BidirectionalStreamingMethodClient(
            self, self.rpcs, channel, method,
            self.default_stream_timeout_s)

    raise AssertionError(f'Unknown method type {method.type}')
def handle_response(self,
                    rpc: PendingRpc,
                    context,
                    payload,
                    args: tuple = (),
                    kwargs: Optional[dict] = None) -> None:
    """Invokes the callback associated with this RPC.

    Any additional positional and keyword args passed through
    Client.process_packet are forwarded to the callback.

    If the response callback raises, the RPC is cancelled and the exception is
    logged instead of propagating.
    """
    if kwargs is None:
        kwargs = {}

    try:
        context.response(rpc, payload, *args, **kwargs)
    except Exception:  # pylint: disable=broad-except
        self.rpcs.send_cancel(rpc)
        _LOG.exception('Response callback %s for %s raised exception',
                       context.response, rpc)
def handle_completion(self,
                      rpc: PendingRpc,
                      context,
                      status: Status,
                      args: tuple = (),
                      kwargs: Optional[dict] = None) -> None:
    """Invokes the completion callback associated with this RPC.

    Extra positional and keyword args are forwarded to the callback. An
    exception raised by the callback is logged instead of propagating; unlike
    handle_response, no cancel is sent because the RPC has already completed.
    """
    if kwargs is None:
        kwargs = {}

    try:
        context.completion(rpc, status, *args, **kwargs)
    except Exception:  # pylint: disable=broad-except
        _LOG.exception('Completion callback %s for %s raised exception',
                       context.completion, rpc)
485 def handle_error(self,
491 kwargs: dict = None) -> None:
496 context.error(rpc, status, *args, **kwargs)
497 except: # pylint: disable=bare-except
498 _LOG.exception('Error callback %s for %s raised exception',