Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / src / python / grpcio / grpc / aio / _channel.py
1 # Copyright 2019 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Invocation-side implementation of gRPC Asyncio Python."""
15
16 import asyncio
17 import sys
18 from typing import Any, Iterable, Optional, Sequence, List
19
20 import grpc
21 from grpc import _common, _compression, _grpcio_metadata
22 from grpc._cython import cygrpc
23
24 from . import _base_call, _base_channel
25 from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
26                     UnaryUnaryCall)
27 from ._interceptor import (
28     InterceptedUnaryUnaryCall, InterceptedUnaryStreamCall,
29     InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor,
30     UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
31     StreamUnaryClientInterceptor, StreamStreamClientInterceptor)
32 from ._metadata import Metadata
33 from ._typing import (ChannelArgumentType, DeserializingFunction,
34                       SerializingFunction, RequestIterableType)
35 from ._utils import _timeout_to_deadline
36
# User-agent string advertised by every asyncio channel; it is appended to the
# channel arguments by _augment_channel_arguments().
_USER_AGENT = f'grpc-python-asyncio/{_grpcio_metadata.__version__}'
38
# Compare the full version tuple: inspecting only the minor component would
# select the wrong implementation on any major version other than 3.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the tasks of the current event loop (pre-3.7 API)."""
        return asyncio.Task.all_tasks()
else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the not-yet-finished tasks of the running event loop."""
        return asyncio.all_tasks()
47
48
def _augment_channel_arguments(base_options: ChannelArgumentType,
                               compression: Optional[grpc.Compression]):
    """Extends the user's channel options with library-managed options.

    Args:
      base_options: The channel options supplied by the user.
      compression: The channel-wide compression setting, or None.

    Returns:
      A tuple holding the user options, followed by the option produced by
      `_compression.create_channel_option(compression)`, followed by the
      primary user-agent option.
    """
    compression_option = _compression.create_channel_option(compression)
    user_agent_option = ((cygrpc.ChannelArgKey.primary_user_agent_string,
                          _USER_AGENT),)
    augmented = tuple(base_options)
    augmented += compression_option
    augmented += user_agent_option
    return augmented
59
60
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Stores the attributes shared by every call made through this object.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The fully-qualified method name, already encoded to bytes.
          request_serializer: Serializer applied to outgoing requests.
          response_deserializer: Deserializer applied to incoming responses.
          interceptors: Interceptors to run for each call, or None/empty for
            direct (non-intercepted) calls.
          loop: The event loop the calls are bound to.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors

    @staticmethod
    def _init_metadata(
            metadata: Optional[Metadata] = None,
            compression: Optional[grpc.Compression] = None) -> Metadata:
        """Builds the final metadata to be used for the current call.

        Args:
          metadata: Caller-supplied metadata. A fresh empty ``Metadata`` is
            used when it is None (or otherwise falsy).
          compression: When set, the metadata is rebuilt from
            ``_compression.augment_metadata(metadata, compression)``.

        Returns:
          The Metadata to send with the call.
        """
        metadata = metadata or Metadata()
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression))
        return metadata
103
104
class UnaryUnaryMultiCallable(_BaseMultiCallable,
                              _base_channel.UnaryUnaryMultiCallable):
    """Factory of single-request / single-response calls."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryUnaryCall:
        """Starts a unary-unary RPC and returns its call object.

        With interceptors configured, an InterceptedUnaryUnaryCall is built
        from the relative ``timeout``; otherwise a plain UnaryUnaryCall is
        built from the deadline derived from ``timeout``.
        """
        final_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, final_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
                              final_metadata, credentials, wait_for_ready,
                              self._channel, self._method,
                              self._request_serializer,
                              self._response_deserializer, self._loop)
134
135
class UnaryStreamMultiCallable(_BaseMultiCallable,
                               _base_channel.UnaryStreamMultiCallable):
    """Factory of single-request / streaming-response calls."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryStreamCall:
        """Starts a unary-stream RPC and returns its call object.

        The direct call receives the absolute deadline derived from
        ``timeout``. The intercepted call receives the relative ``timeout``
        itself, as the unary-unary path does. Passing the deadline there made
        an absolute timestamp be treated as a duration, which effectively
        discarded the caller's timeout.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = UnaryStreamCall(request, _timeout_to_deadline(timeout),
                                   metadata, credentials, wait_for_ready,
                                   self._channel, self._method,
                                   self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            call = InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, metadata, credentials,
                wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
166
167
class StreamUnaryMultiCallable(_BaseMultiCallable,
                               _base_channel.StreamUnaryMultiCallable):
    """Factory of streaming-request / single-response calls."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamUnaryCall:
        """Starts a stream-unary RPC and returns its call object.

        The direct call receives the absolute deadline derived from
        ``timeout``. The intercepted call receives the relative ``timeout``
        itself, as the unary-unary path does, rather than the deadline.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = StreamUnaryCall(request_iterator,
                                   _timeout_to_deadline(timeout), metadata,
                                   credentials, wait_for_ready, self._channel,
                                   self._method, self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            call = InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
197
198
class StreamStreamMultiCallable(_BaseMultiCallable,
                                _base_channel.StreamStreamMultiCallable):
    """Factory of streaming-request / streaming-response calls."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamStreamCall:
        """Starts a stream-stream RPC and returns its call object.

        The direct call receives the absolute deadline derived from
        ``timeout``. The intercepted call receives the relative ``timeout``
        itself, as the unary-unary path does, rather than the deadline.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = StreamStreamCall(request_iterator,
                                    _timeout_to_deadline(timeout), metadata,
                                    credentials, wait_for_ready, self._channel,
                                    self._method, self._request_serializer,
                                    self._response_deserializer, self._loop)
        else:
            call = InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
228
229
class Channel(_base_channel.Channel):
    """Asyncio client channel.

    Interceptors given at construction are sorted by the interceptor
    interfaces they implement. Each group is handed to the matching
    multi-callable factory: unary_unary, unary_stream, stream_unary or
    stream_stream.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(self, target: str, options: ChannelArgumentType,
                 credentials: Optional[grpc.ChannelCredentials],
                 compression: Optional[grpc.Compression],
                 interceptors: Optional[Sequence[ClientInterceptor]]):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel. An interceptor
            implementing several interceptor interfaces is registered for
            each of them.

        Raises:
          ValueError: If an interceptor implements none of the four client
            interceptor interfaces.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            registries = (
                (UnaryUnaryClientInterceptor, self._unary_unary_interceptors),
                (UnaryStreamClientInterceptor, self._unary_stream_interceptors),
                (StreamUnaryClientInterceptor, self._stream_unary_interceptors),
                (StreamStreamClientInterceptor,
                 self._stream_stream_interceptors),
            )
            for interceptor in interceptors:
                # Check every interface independently (not an if/elif chain)
                # so that an interceptor implementing several of them is
                # attached to all the matching call types.
                registered = False
                for interface, registry in registries:
                    if isinstance(interceptor, interface):
                        registry.append(interceptor)
                        registered = True
                if not registered:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor) +
                        "{} or ".format(UnaryUnaryClientInterceptor.__name__) +
                        "{} or ".format(UnaryStreamClientInterceptor.__name__) +
                        "{} or ".format(StreamUnaryClientInterceptor.__name__) +
                        "{}. ".format(StreamStreamClientInterceptor.__name__))

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression), credentials,
            self._loop)

    async def __aenter__(self):
        """Enters an `async with` block; the value bound is this channel."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Closes the channel, without a grace period, on leaving `async with`."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Closes the channel, optionally letting in-flight calls finish.

        Args:
          grace: If truthy, the number of seconds to wait for the running tasks
            that drive calls on this channel. Calls still running afterwards
            (or immediately, when grace is falsy) are cancelled.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if 'frame' in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get('self')
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, '_channel'):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, '_cython_call'):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f'Unrecognized call object: {candidate}')

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            # asyncio.wait() no longer accepts `loop=` (deprecated in 3.8,
            # removed in 3.10); it uses the running event loop.
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Closes the channel; see `_close` for the meaning of `grace`."""
        await self._close(grace)

    def get_state(self,
                  try_to_connect: bool = False) -> grpc.ChannelConnectivity:
        """Returns the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's connectivity check.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Waits until the connectivity state differs from the given one.

        Args:
          last_observed_state: The state the caller last observed.
        """
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None)
        # Keep the await outside the assert: `python -O` strips assert
        # statements, which would skip the wait entirely and make
        # channel_ready() spin.
        assert state_changed

    async def channel_ready(self) -> None:
        """Returns once the channel reaches READY, prompting it to connect."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryUnaryMultiCallable:
        """Creates a multi-callable for a unary-unary method."""
        return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
                                       request_serializer,
                                       response_deserializer,
                                       self._unary_unary_interceptors,
                                       self._loop)

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryStreamMultiCallable:
        """Creates a multi-callable for a unary-stream method."""
        return UnaryStreamMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._unary_stream_interceptors,
                                        self._loop)

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamUnaryMultiCallable:
        """Creates a multi-callable for a stream-unary method."""
        return StreamUnaryMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._stream_unary_interceptors,
                                        self._loop)

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamStreamMultiCallable:
        """Creates a multi-callable for a stream-stream method."""
        return StreamStreamMultiCallable(self._channel, _common.encode(method),
                                         request_serializer,
                                         response_deserializer,
                                         self._stream_stream_interceptors,
                                         self._loop)
427
428
def insecure_channel(
        target: str,
        options: Optional[ChannelArgumentType] = None,
        compression: Optional[grpc.Compression] = None,
        interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates an asynchronous Channel to a server without transport security.

    Args:
      target: The address of the server.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) used to configure the channel; None means no
        extra options.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel. This is an EXPERIMENTAL option.
      interceptors: An optional sequence of interceptors to run for every
        call made on this channel.

    Returns:
      A Channel.
    """
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)
450
451
def secure_channel(target: str,
                   credentials: grpc.ChannelCredentials,
                   options: Optional[ChannelArgumentType] = None,
                   compression: Optional[grpc.Compression] = None,
                   interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates an asynchronous Channel to a server over a secure transport.

    Args:
      target: The address of the server.
      credentials: A ChannelCredentials instance; its underlying Cython
        credentials object is handed to the Channel.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) used to configure the channel; None means no
        extra options.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel. This is an EXPERIMENTAL option.
      interceptors: An optional sequence of interceptors to run for every
        call made on this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = () if options is None else options
    cython_credentials = credentials._credentials
    return Channel(target, channel_options, cython_credentials, compression,
                   interceptors)