1 # Copyright 2017 gRPC authors.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Implementation of gRPC Python interceptors."""
class _ServicePipeline(object):
    """Chains server interceptors in front of a terminal thunk.

    ``execute`` calls each interceptor's ``intercept_service(continuation,
    context)`` in order. Each ``continuation`` resumes the chain at the next
    interceptor, and once every interceptor has deferred, the original
    ``thunk(context)`` is called.
    """

    def __init__(self, interceptors):
        # Snapshot into a tuple so later changes to the caller's sequence
        # (or exhaustion of an iterator) cannot alter this pipeline.
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk, index):
        """Return a one-argument callable that resumes the chain at ``index``."""
        return lambda context: self._intercept_at(thunk, index, context)

    def _intercept_at(self, thunk, index, context):
        """Run the interceptor at ``index``, or ``thunk`` if none remain.

        Args:
          thunk: Terminal callable taking ``context``.
          index: Position in ``self.interceptors`` to resume from.
          context: Value handed to each interceptor and finally to ``thunk``.

        Returns:
          The interceptor's return value, or ``thunk(context)`` once ``index``
          is past the last interceptor.
        """
        if index >= len(self.interceptors):
            # Every interceptor has called its continuation: reach the
            # terminal thunk instead of silently returning None.
            return thunk(context)
        interceptor = self.interceptors[index]
        continuation = self._continuation(thunk, index + 1)
        return interceptor.intercept_service(continuation, context)

    def execute(self, thunk, context):
        """Run the whole interceptor chain, ending with ``thunk(context)``."""
        return self._intercept_at(thunk, 0, context)
def service_pipeline(interceptors):
    """Build a server interceptor pipeline, or None when there is nothing to run.

    Args:
      interceptors: A (possibly empty or None) sequence of server interceptors.

    Returns:
      A ``_ServicePipeline`` over ``interceptors``, or None if it is falsy.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
class _ClientCallDetails(
        collections.namedtuple('_ClientCallDetails',
                               ('method', 'timeout', 'metadata', 'credentials',
                                'wait_for_ready', 'compression')),
        grpc.ClientCallDetails):
    """Namedtuple-backed record of the options for one client call.

    The namedtuple base supplies the ``method``, ``timeout``, ``metadata``,
    ``credentials``, ``wait_for_ready`` and ``compression`` attributes.
    Subclassing ``grpc.ClientCallDetails`` lets the record be passed wherever
    that interface is expected, e.g. to client interceptors.
    """
def _unwrap_client_call_details(call_details, default_details):
    """Read call options from ``call_details``, falling back field by field.

    The details object an interceptor hands back may define only some of the
    call attributes. Each attribute missing from ``call_details`` (i.e.
    raising AttributeError) is taken from ``default_details`` instead.

    Args:
      call_details: Details object supplied to a continuation; may be partial.
      default_details: Fallback details object (the call's original
        ``_ClientCallDetails``).

    Returns:
      A ``(method, timeout, metadata, credentials, wait_for_ready,
      compression)`` tuple.
    """
    unwrapped = []
    for field in ('method', 'timeout', 'metadata', 'credentials',
                  'wait_for_ready', 'compression'):
        try:
            value = getattr(call_details, field)
        except AttributeError:
            # Consult the default lazily, only for fields that are missing.
            value = getattr(default_details, field)
        unwrapped.append(value)
    return tuple(unwrapped)
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):  # pylint: disable=too-many-ancestors
    """Already-finished, failed call object wrapping an interception error.

    The multi-callables in this module return this object instead of raising
    when an exception escapes while intercepting an RPC. It reports
    ``grpc.StatusCode.INTERNAL``, behaves as a completed future, and re-raises
    the captured exception from ``result()`` and on iteration. Iteration
    matters because it is also returned in place of a streaming response
    iterator.
    """

    def __init__(self, exception, traceback):
        super(_FailureOutcome, self).__init__()
        self._exception = exception
        self._traceback = traceback

    # grpc.Call: no response was received, so there is no metadata.

    def initial_metadata(self):
        return None

    def trailing_metadata(self):
        return None

    def code(self):
        return grpc.StatusCode.INTERNAL

    def details(self):
        return 'Exception raised while intercepting the RPC'

    # grpc.RpcContext / grpc.Future: the call is already over.

    def cancel(self):
        return False

    def cancelled(self):
        return False

    def is_active(self):
        return False

    def time_remaining(self):
        return None

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        raise self._exception

    def exception(self, ignored_timeout=None):
        return self._exception

    def traceback(self, ignored_timeout=None):
        return self._traceback

    def add_callback(self, unused_callback):
        # The call has already terminated, so the callback is not registered.
        return False

    def add_done_callback(self, fn):
        # Already done: invoke the callback immediately.
        fn(self)

    # Iterator protocol, used when standing in for a response stream.

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        # Python 2 spelling of the iterator protocol.
        return self.__next__()
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Completed call object pairing a unary response with its ``grpc.Call``.

    Metadata, status, cancellation and callback queries are delegated to the
    wrapped ``call``. The ``grpc.Future`` side reports a finished, successful
    RPC whose result is ``response``.
    """

    def __init__(self, response, call):
        self._response = response
        # Every delegating method below reads self._call.
        self._call = call

    # Delegated to the underlying call.

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def cancel(self):
        return self._call.cancel()

    def add_callback(self, callback):
        return self._call.add_callback(callback)

    # grpc.Future: the response is already in hand.

    def cancelled(self):
        return False

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        return self._response

    def exception(self, ignored_timeout=None):
        return None

    def traceback(self, ignored_timeout=None):
        return None

    def add_done_callback(self, fn):
        # Already done: invoke the callback immediately.
        fn(self)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that runs every call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's unary-unary
    multi-callable. Because of this, the interceptor can redirect the call by
    changing ``method`` in the details it passes to the continuation.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Invoke the RPC synchronously and return only the response."""
        response, _ = self._with_call(request,
                                      timeout=timeout,
                                      metadata=metadata,
                                      credentials=credentials,
                                      wait_for_ready=wait_for_ready,
                                      compression=compression)
        return response

    def _with_call(self,
                   request,
                   timeout=None,
                   metadata=None,
                   credentials=None,
                   wait_for_ready=None,
                   compression=None):
        """Run the interceptor synchronously and return ``(response, call)``.

        The continuation returns the outcome instead of raising it:
        - a ``_UnaryOutcome`` on success;
        - the ``grpc.RpcError`` itself when the RPC fails;
        - a ``_FailureOutcome`` for any other exception.
        ``result()`` on the object the interceptor returns then produces the
        response or raises.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials,
                    wait_for_ready=new_wait_for_ready,
                    compression=new_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised, so the interceptor can inspect it.
                # NOTE(review): assumes the RpcError is call-like and its
                # result() raises it; confirm against grpc's implementation.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_unary_unary(continuation,
                                                       client_call_details,
                                                       request)
        return call.result(), call

    def with_call(self,
                  request,
                  timeout=None,
                  metadata=None,
                  credentials=None,
                  wait_for_ready=None,
                  compression=None):
        """Invoke the RPC synchronously and return ``(response, call)``."""
        return self._with_call(request,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request,
               timeout=None,
               metadata=None,
               credentials=None,
               wait_for_ready=None,
               compression=None):
        """Start the RPC asynchronously through the interceptor.

        Returns:
          Whatever the interceptor returns (typically the continuation's
          result, i.e. the underlying multi-callable's ``future(...)``), or a
          ``_FailureOutcome`` if the interceptor raises.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method).future(
                request,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that runs every call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's unary-stream
    multi-callable, so the interceptor can redirect the call via the details
    it passes to the continuation.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Start the RPC through the interceptor.

        Returns:
          Whatever the interceptor returns (typically the underlying response
          iterator/call), or a ``_FailureOutcome`` if the interceptor raises.
          Iterating a ``_FailureOutcome`` re-raises the captured exception.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method)(request,
                                           timeout=new_timeout,
                                           metadata=new_metadata,
                                           credentials=new_credentials,
                                           wait_for_ready=new_wait_for_ready,
                                           compression=new_compression)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that runs every call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's stream-unary
    multi-callable, so the interceptor can redirect the call via the details
    it passes to the continuation.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Invoke the RPC synchronously and return only the response."""
        response, _ = self._with_call(request_iterator,
                                      timeout=timeout,
                                      metadata=metadata,
                                      credentials=credentials,
                                      wait_for_ready=wait_for_ready,
                                      compression=compression)
        return response

    def _with_call(self,
                   request_iterator,
                   timeout=None,
                   metadata=None,
                   credentials=None,
                   wait_for_ready=None,
                   compression=None):
        """Run the interceptor synchronously and return ``(response, call)``.

        The continuation returns the outcome instead of raising it:
        - a ``_UnaryOutcome`` on success;
        - the ``grpc.RpcError`` on RPC failure;
        - a ``_FailureOutcome`` for any other exception.
        ``result()`` on the interceptor's return value then produces the
        response or raises.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            try:
                response, call = self._thunk(new_method).with_call(
                    request_iterator,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials,
                    wait_for_ready=new_wait_for_ready,
                    compression=new_compression)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised, so the interceptor can inspect it.
                # NOTE(review): assumes its result() raises it; confirm.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_stream_unary(continuation,
                                                        client_call_details,
                                                        request_iterator)
        return call.result(), call

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None,
                  wait_for_ready=None,
                  compression=None):
        """Invoke the RPC synchronously and return ``(response, call)``."""
        return self._with_call(request_iterator,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None,
               wait_for_ready=None,
               compression=None):
        """Start the RPC asynchronously through the interceptor.

        Returns:
          Whatever the interceptor returns (typically the underlying
          multi-callable's ``future(...)``), or a ``_FailureOutcome`` if the
          interceptor raises.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method).future(
                request_iterator,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that runs every call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's stream-stream
    multi-callable, so the interceptor can redirect the call via the details
    it passes to the continuation.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None,
                 wait_for_ready=None,
                 compression=None):
        """Start the RPC through the interceptor.

        Returns:
          Whatever the interceptor returns (typically the underlying response
          iterator/call), or a ``_FailureOutcome`` if the interceptor raises.
          Iterating a ``_FailureOutcome`` re-raises the captured exception.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (new_method, new_timeout, new_metadata, new_credentials,
             new_wait_for_ready,
             new_compression) = (_unwrap_client_call_details(
                 new_details, client_call_details))
            return self._thunk(new_method)(request_iterator,
                                           timeout=new_timeout,
                                           metadata=new_metadata,
                                           credentials=new_credentials,
                                           wait_for_ready=new_wait_for_ready,
                                           compression=new_compression)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _Channel(grpc.Channel):
    """``grpc.Channel`` wrapper that routes calls through one interceptor.

    Each multi-callable factory wraps the underlying channel's multi-callable
    only when ``interceptor`` implements the matching client-interceptor
    interface. Otherwise it returns the underlying multi-callable unchanged.
    Connectivity and lifecycle calls are forwarded to the wrapped channel.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        self._channel.unsubscribe(callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):
        # The thunk defers building the real multi-callable so that the
        # method name can be taken from the interceptor's call details.
        thunk = lambda m: self._channel.unary_unary(m, request_serializer,
                                                    response_deserializer)
        if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        thunk = lambda m: self._channel.unary_stream(m, request_serializer,
                                                     response_deserializer)
        if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        thunk = lambda m: self._channel.stream_unary(m, request_serializer,
                                                     response_deserializer)
        if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):
        thunk = lambda m: self._channel.stream_stream(m, request_serializer,
                                                      response_deserializer)
        if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
            return _StreamStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Never suppress an exception raised inside the with-block.
        return False

    def close(self):
        self._channel.close()
def intercept_channel(channel, *interceptors):
    """Wrap ``channel`` so its RPCs pass through ``interceptors``.

    The channel is wrapped starting from the last interceptor, so the first
    interceptor given ends up outermost and sees each call first.

    Args:
      channel: The channel to intercept.
      *interceptors: Zero or more client interceptors. Each must implement at
        least one of the four grpc client-interceptor interfaces.

    Returns:
      The wrapped channel, or ``channel`` itself when no interceptors are given.

    Raises:
      TypeError: If an interceptor implements none of the client-interceptor
        interfaces.
    """
    interceptor_types = (grpc.UnaryUnaryClientInterceptor,
                         grpc.UnaryStreamClientInterceptor,
                         grpc.StreamUnaryClientInterceptor,
                         grpc.StreamStreamClientInterceptor)
    # *interceptors is already a tuple, so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, interceptor_types):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel