Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio / grpc / _interceptor.py
1 # Copyright 2017 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Implementation of gRPC Python interceptors."""
15
16 import collections
17 import sys
18
19 import grpc
20
21
22 class _ServicePipeline(object):
23
24     def __init__(self, interceptors):
25         self.interceptors = tuple(interceptors)
26
27     def _continuation(self, thunk, index):
28         return lambda context: self._intercept_at(thunk, index, context)
29
30     def _intercept_at(self, thunk, index, context):
31         if index < len(self.interceptors):
32             interceptor = self.interceptors[index]
33             thunk = self._continuation(thunk, index + 1)
34             return interceptor.intercept_service(thunk, context)
35         else:
36             return thunk(context)
37
38     def execute(self, thunk, context):
39         return self._intercept_at(thunk, 0, context)
40
41
def service_pipeline(interceptors):
    """Return a ``_ServicePipeline`` for ``interceptors``, or None if falsy."""
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
44
45
class _ClientCallDetails(
        collections.namedtuple('_ClientCallDetails', (
            'method',
            'timeout',
            'metadata',
            'credentials',
            'wait_for_ready',
            'compression',
        )), grpc.ClientCallDetails):
    """Immutable record of per-call settings handed to client interceptors.

    Fields, in positional order: ``method``, ``timeout``, ``metadata``,
    ``credentials``, ``wait_for_ready`` and ``compression``.
    """
52
53
54 def _unwrap_client_call_details(call_details, default_details):
55     try:
56         method = call_details.method
57     except AttributeError:
58         method = default_details.method
59
60     try:
61         timeout = call_details.timeout
62     except AttributeError:
63         timeout = default_details.timeout
64
65     try:
66         metadata = call_details.metadata
67     except AttributeError:
68         metadata = default_details.metadata
69
70     try:
71         credentials = call_details.credentials
72     except AttributeError:
73         credentials = default_details.credentials
74
75     try:
76         wait_for_ready = call_details.wait_for_ready
77     except AttributeError:
78         wait_for_ready = default_details.wait_for_ready
79
80     try:
81         compression = call_details.compression
82     except AttributeError:
83         compression = default_details.compression
84
85     return method, timeout, metadata, credentials, wait_for_ready, compression
86
87
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):  # pylint: disable=too-many-ancestors
    """Already-finished, failed outcome carrying a captured exception.

    The stored exception is re-raised from ``result()`` and on iteration, and
    is returned (not raised) from ``exception()``.
    """

    def __init__(self, exception, traceback):
        super(_FailureOutcome, self).__init__()
        self._exception = exception
        self._traceback = traceback

    # --- Status reporting: fixed values, nothing is read from a real RPC. ---

    def code(self):
        """Always report ``grpc.StatusCode.INTERNAL``."""
        return grpc.StatusCode.INTERNAL

    def details(self):
        """Return a fixed description of the failure."""
        return 'Exception raised while intercepting the RPC'

    def initial_metadata(self):
        """Return None; this outcome holds no initial metadata."""
        return None

    def trailing_metadata(self):
        """Return None; this outcome holds no trailing metadata."""
        return None

    # --- Call-context queries. ---

    def is_active(self):
        """Always False."""
        return False

    def time_remaining(self):
        """Always None."""
        return None

    def add_callback(self, unused_callback):
        """Always return False; the callback is never invoked."""
        return False

    # --- Future state: already done, never running or cancelled. ---

    def cancel(self):
        """Always return False."""
        return False

    def cancelled(self):
        """Always False."""
        return False

    def running(self):
        """Always False."""
        return False

    def done(self):
        """Always True."""
        return True

    def result(self, ignored_timeout=None):
        """Raise the stored exception; the timeout is ignored."""
        raise self._exception

    def exception(self, ignored_timeout=None):
        """Return the stored exception; the timeout is ignored."""
        return self._exception

    def traceback(self, ignored_timeout=None):
        """Return the stored traceback; the timeout is ignored."""
        return self._traceback

    def add_done_callback(self, fn):
        """Call ``fn(self)`` immediately."""
        fn(self)

    # --- Iterator protocol: every step raises the stored exception. ---

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        """Python 2 spelling of ``__next__``."""
        return self.__next__()
148
149
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Already-finished, successful outcome pairing a response with its call.

    Call-related queries are forwarded to the wrapped call object; the
    future-related methods answer from fixed values and the stored response.
    """

    def __init__(self, response, call):
        self._response = response
        self._call = call

    # --- Forwarded to the wrapped call. ---

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def add_callback(self, callback):
        return self._call.add_callback(callback)

    def cancel(self):
        return self._call.cancel()

    # --- Future state: fixed answers, independent of the wrapped call. ---

    def cancelled(self):
        """Always False."""
        return False

    def running(self):
        """Always False."""
        return False

    def done(self):
        """Always True."""
        return True

    def result(self, ignored_timeout=None):
        """Return the stored response; the timeout is ignored."""
        return self._response

    def exception(self, ignored_timeout=None):
        """Always None; the timeout is ignored."""
        return None

    def traceback(self, ignored_timeout=None):
        """Always None; the timeout is ignored."""
        return None

    def add_done_callback(self, fn):
        """Call ``fn(self)`` immediately."""
        fn(self)
200
201
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes invocations via an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable, ``method``
    is the default method name, and ``interceptor`` must provide
    ``intercept_unary_unary``.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None,
                 wait_for_ready=None, compression=None):
        """Invoke the RPC through the interceptor and return the response."""
        response, _ = self._with_call(request,
                                      timeout=timeout,
                                      metadata=metadata,
                                      credentials=credentials,
                                      wait_for_ready=wait_for_ready,
                                      compression=compression)
        return response

    def _with_call(self, request, timeout=None, metadata=None,
                   credentials=None, wait_for_ready=None, compression=None):
        """Invoke through the interceptor; return ``(result, outcome)``.

        ``outcome.result()`` is evaluated here, so a failed outcome raises.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request):
            # Fields the interceptor left off ``new_details`` fall back to
            # the details built from this invocation's own arguments.
            (target, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready, call_compression) = (
                 _unwrap_client_call_details(new_details, defaults))
            try:
                response, call = self._thunk(target).with_call(
                    request,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials,
                    wait_for_ready=call_wait_for_ready,
                    compression=call_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised: the interceptor receives the error
                # object as the continuation's result.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_unary_unary(
            continuation, defaults, request)
        return outcome.result(), outcome

    def with_call(self, request, timeout=None, metadata=None,
                  credentials=None, wait_for_ready=None, compression=None):
        """Invoke through the interceptor; return ``(result, outcome)``."""
        return self._with_call(request,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self, request, timeout=None, metadata=None, credentials=None,
               wait_for_ready=None, compression=None):
        """Invoke through the interceptor and return whatever it returns.

        An exception escaping the interceptor is wrapped in a
        ``_FailureOutcome`` and returned rather than propagated.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request):
            (target, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready, call_compression) = (
                 _unwrap_client_call_details(new_details, defaults))
            return self._thunk(target).future(
                request,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, defaults, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
302
303
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes invocations via an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable, ``method``
    is the default method name, and ``interceptor`` must provide
    ``intercept_unary_stream``.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None,
                 wait_for_ready=None, compression=None):
        """Invoke through the interceptor and return whatever it returns.

        An exception escaping the interceptor is wrapped in a
        ``_FailureOutcome`` and returned rather than propagated.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request):
            # Fields the interceptor left off ``new_details`` fall back to
            # the details built from this invocation's own arguments.
            (target, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready, call_compression) = (
                 _unwrap_client_call_details(new_details, defaults))
            underlying = self._thunk(target)
            return underlying(request,
                              timeout=call_timeout,
                              metadata=call_metadata,
                              credentials=call_credentials,
                              wait_for_ready=call_wait_for_ready,
                              compression=call_compression)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, defaults, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
339
340
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes invocations via an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable, ``method``
    is the default method name, and ``interceptor`` must provide
    ``intercept_stream_unary``.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request_iterator, timeout=None, metadata=None,
                 credentials=None, wait_for_ready=None, compression=None):
        """Invoke the RPC through the interceptor and return the response."""
        response, _ = self._with_call(request_iterator,
                                      timeout=timeout,
                                      metadata=metadata,
                                      credentials=credentials,
                                      wait_for_ready=wait_for_ready,
                                      compression=compression)
        return response

    def _with_call(self, request_iterator, timeout=None, metadata=None,
                   credentials=None, wait_for_ready=None, compression=None):
        """Invoke through the interceptor; return ``(result, outcome)``.

        ``outcome.result()`` is evaluated here, so a failed outcome raises.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            # Fields the interceptor left off ``new_details`` fall back to
            # the details built from this invocation's own arguments.
            (target, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready, call_compression) = (
                 _unwrap_client_call_details(new_details, defaults))
            try:
                response, call = self._thunk(target).with_call(
                    request_iterator,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials,
                    wait_for_ready=call_wait_for_ready,
                    compression=call_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised: the interceptor receives the error
                # object as the continuation's result.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_stream_unary(
            continuation, defaults, request_iterator)
        return outcome.result(), outcome

    def with_call(self, request_iterator, timeout=None, metadata=None,
                  credentials=None, wait_for_ready=None, compression=None):
        """Invoke through the interceptor; return ``(result, outcome)``."""
        return self._with_call(request_iterator,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self, request_iterator, timeout=None, metadata=None,
               credentials=None, wait_for_ready=None, compression=None):
        """Invoke through the interceptor and return whatever it returns.

        An exception escaping the interceptor is wrapped in a
        ``_FailureOutcome`` and returned rather than propagated.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (target, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready, call_compression) = (
                 _unwrap_client_call_details(new_details, defaults))
            return self._thunk(target).future(
                request_iterator,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, defaults, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
441
442
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes invocations via an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable, ``method``
    is the default method name, and ``interceptor`` must provide
    ``intercept_stream_stream``.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request_iterator, timeout=None, metadata=None,
                 credentials=None, wait_for_ready=None, compression=None):
        """Invoke through the interceptor and return whatever it returns.

        An exception escaping the interceptor is wrapped in a
        ``_FailureOutcome`` and returned rather than propagated.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials, wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            # Fields the interceptor left off ``new_details`` fall back to
            # the details built from this invocation's own arguments.
            (target, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready, call_compression) = (
                 _unwrap_client_call_details(new_details, defaults))
            underlying = self._thunk(target)
            return underlying(request_iterator,
                              timeout=call_timeout,
                              metadata=call_metadata,
                              credentials=call_credentials,
                              wait_for_ready=call_wait_for_ready,
                              compression=call_compression)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, defaults, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
478
479
class _Channel(grpc.Channel):
    """Channel wrapper that applies a single interceptor to a wrapped channel.

    For each RPC arity, the interceptor is used only if it is an instance of
    the matching client-interceptor interface; otherwise the wrapped
    channel's own multi-callable is returned unchanged.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        """Forward the subscription to the wrapped channel."""
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        """Forward the unsubscription to the wrapped channel."""
        self._channel.unsubscribe(callback)

    def unary_unary(self, method, request_serializer=None,
                    response_deserializer=None):
        """Return a unary-unary multi-callable, intercepted when applicable."""

        def thunk(m):
            return self._channel.unary_unary(m, request_serializer,
                                             response_deserializer)

        if not isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return thunk(method)
        return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)

    def unary_stream(self, method, request_serializer=None,
                     response_deserializer=None):
        """Return a unary-stream multi-callable, intercepted when applicable."""

        def thunk(m):
            return self._channel.unary_stream(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return thunk(method)
        return _UnaryStreamMultiCallable(thunk, method, self._interceptor)

    def stream_unary(self, method, request_serializer=None,
                     response_deserializer=None):
        """Return a stream-unary multi-callable, intercepted when applicable."""

        def thunk(m):
            return self._channel.stream_unary(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return thunk(method)
        return _StreamUnaryMultiCallable(thunk, method, self._interceptor)

    def stream_stream(self, method, request_serializer=None,
                      response_deserializer=None):
        """Return a stream-stream multi-callable, intercepted when applicable."""

        def thunk(m):
            return self._channel.stream_stream(m, request_serializer,
                                               response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamStreamClientInterceptor):
            return thunk(method)
        return _StreamStreamMultiCallable(thunk, method, self._interceptor)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close on leaving the ``with`` block; returning False lets any
        # in-flight exception keep propagating.
        self._close()
        return False

    def close(self):
        """Close the wrapped channel."""
        self._channel.close()
548
549
def intercept_channel(channel, *interceptors):
    """Wrap ``channel`` so that its RPCs pass through ``interceptors``.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more client interceptors. They are applied
        last-to-first, so the first one listed ends up as the outermost
        wrapper.

    Returns:
      ``channel`` itself when no interceptors are given; otherwise a
      ``_Channel`` wrapping it once per interceptor.

    Raises:
      TypeError: If an interceptor is not an instance of any of the four
        client-interceptor interfaces.
    """
    if not interceptors:
        # Nothing to wrap; hand the channel back untouched.
        return channel

    accepted_types = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # ``interceptors`` is already a tuple, so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, accepted_types):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel