Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests_aio / unit / interceptor_test.py
1 # Copyright 2019 The gRPC Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import asyncio
15 import logging
16 import unittest
17
18 import grpc
19
20 from grpc.experimental import aio
21 from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY
22 from tests_aio.unit import _constants
23 from tests_aio.unit import _common
24 from tests_aio.unit._test_base import AioTestBase
25 from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
26
# Details string the cancellation tests expect on a call cancelled by the
# client itself.
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'

# Metadata pairs appended to the outgoing call by the metadata-modifying
# interceptor. The test looks for the first pair in the call's initial
# metadata and for the second pair in its trailing metadata.
_INITIAL_METADATA_TO_INJECT = (
    (_INITIAL_METADATA_KEY, 'extra info'),
    (_TRAILING_METADATA_KEY, b'\x13\x37'),
)
32
33
class TestUnaryUnaryClientInterceptor(AioTestBase):
    """Tests for ``aio.UnaryUnaryClientInterceptor`` implementations.

    Each test installs one or more interceptors on an insecure channel that
    targets the test server started in ``setUp`` and checks what the
    interceptors observe or how they alter the outcome of the RPC.
    """

    async def setUp(self):
        """Start a test server and keep its target address and handle."""
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        """Stop the test server started in ``setUp``."""
        await self._server.stop(None)

    def test_invalid_interceptor(self):
        """A channel rejects objects that are not client interceptors."""

        class InvalidInterceptor:
            """Just an invalid Interceptor"""

        with self.assertRaises(ValueError):
            aio.insecure_channel("", interceptors=[InvalidInterceptor()])

    async def test_executed_right_order(self):
        """Interceptors run once each, in the order given to the channel."""

        interceptors_executed = []

        class Interceptor(aio.UnaryUnaryClientInterceptor):
            """Interceptor used for testing if the interceptor is being called"""

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # Record the instance so the test can compare the execution
                # order against the configured order.
                interceptors_executed.append(self)
                call = await continuation(client_call_details, request)
                return call

        interceptors = [Interceptor() for _ in range(2)]

        async with aio.insecure_channel(self._server_target,
                                        interceptors=interceptors) as channel:
            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            # Check that all interceptors were executed, and were executed
            # in the right order.
            self.assertSequenceEqual(interceptors_executed, interceptors)

            self.assertIsInstance(response, messages_pb2.SimpleResponse)

    @unittest.expectedFailure
    # TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is
    # implemented in the client-side, this test must be implemented.
    def test_modify_metadata(self):
        """Placeholder; raises until the test is implemented (see TODO)."""
        raise NotImplementedError()

    @unittest.expectedFailure
    # TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is
    # implemented in the client-side, this test must be implemented.
    def test_modify_credentials(self):
        """Placeholder; raises until the test is implemented (see TODO)."""
        raise NotImplementedError()

    async def test_status_code_Ok(self):
        """An interceptor can observe ``StatusCode.OK`` on a successful RPC."""

        class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor):
            """Interceptor used for observing status code Ok returned by the RPC"""

            def __init__(self):
                self.status_code_Ok_observed = False

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                code = await call.code()
                if code == grpc.StatusCode.OK:
                    self.status_code_Ok_observed = True

                return call

        interceptor = StatusCodeOkInterceptor()

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[interceptor]) as channel:

            # when no error StatusCode.OK must be observed
            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            await multicallable(messages_pb2.SimpleRequest())

            self.assertTrue(interceptor.status_code_Ok_observed)

    async def test_add_timeout(self):
        """A timeout injected by an interceptor makes a slow RPC time out."""

        class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor):
            """Interceptor used for adding a timeout to the RPC"""

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # Half of the sleep value used by UnaryCallWithSleep, so the
                # deadline is shorter than that sleep.
                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials)
                return await continuation(new_client_call_details, request)

        interceptor = TimeoutInterceptor()

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[interceptor]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(exception_context.exception.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
                             call.code())

    async def test_retry(self):
        """An interceptor may issue the RPC twice and return the second call."""

        class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
            """Simulates a retry interceptor which ends up making two RPC
            calls."""

            def __init__(self):
                self.calls = []

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):

                # First attempt: a timeout of half the sleep value, expected
                # by the test to end with DEADLINE_EXCEEDED.
                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials)

                # Only awaiting the call is guarded: if creating the call
                # failed there would be no call object to record below.
                call = await continuation(new_client_call_details, request)
                try:
                    await call
                except grpc.RpcError:
                    pass

                self.calls.append(call)

                # Second attempt: no timeout.
                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=None,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials)

                call = await continuation(new_client_call_details, request)
                self.calls.append(call)
                return call

        interceptor = RetryInterceptor()

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[interceptor]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(messages_pb2.SimpleRequest())

            await call

            self.assertEqual(grpc.StatusCode.OK, await call.code())

            # Check that two calls were made, first one finishing with
            # a deadline and second one finishing ok.
            self.assertEqual(len(interceptor.calls), 2)
            self.assertEqual(await interceptor.calls[0].code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)
            self.assertEqual(await interceptor.calls[1].code(),
                             grpc.StatusCode.OK)

    async def test_rpcresponse(self):
        """A raw response returned by an interceptor behaves like a call."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):
            """Raw responses are seen as regular calls"""

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                # Awaited to show that what the next interceptor returned
                # can be awaited like a call; the result is not needed here.
                await call
                return call

        class ResponseInterceptor(aio.UnaryUnaryClientInterceptor):
            """Return a raw response"""
            response = messages_pb2.SimpleResponse()

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # The continuation is never invoked; the canned response is
                # returned instead.
                return ResponseInterceptor.response

        interceptor, interceptor_response = Interceptor(), ResponseInterceptor()

        async with aio.insecure_channel(
                self._server_target,
                interceptors=[interceptor, interceptor_response]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            # Check that the response returned is the very object returned
            # by the interceptor
            self.assertIs(response, ResponseInterceptor.response)

            # Check all of the UnaryUnaryCallResponse attributes
            self.assertTrue(call.done())
            self.assertFalse(call.cancel())
            self.assertFalse(call.cancelled())
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
            self.assertEqual(await call.details(), '')
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())
            self.assertIsNone(await call.debug_error_string())
266
267
class TestInterceptedUnaryUnaryCall(AioTestBase):
    """Tests for the call object returned when interceptors are installed.

    Each test issues a unary-unary RPC through a channel configured with a
    single interceptor and checks the state exposed by the resulting call:
    ``done()``, ``cancelled()``, ``code()``, ``details()`` and metadata.
    """

    async def setUp(self):
        """Start a test server and keep its target address and handle."""
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        """Stop the test server started in ``setUp``."""
        await self._server.stop(None)

    async def test_call_ok(self):
        """Successful RPC; the interceptor returns the call un-awaited."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
            self.assertEqual(await call.details(), '')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_call_ok_awaited(self):
        """Successful RPC; the interceptor awaits the call before returning."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                await call
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
            self.assertEqual(await call.details(), '')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_call_rpc_error(self):
        """Timed-out RPC; the interceptor returns the call un-awaited."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            # Timeout is half of the sleep value used by UnaryCallWithSleep.
            call = multicallable(
                messages_pb2.SimpleRequest(),
                timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)

            with self.assertRaises(aio.AioRpcError):
                await call

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertEqual(await call.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)
            self.assertEqual(await call.details(), 'Deadline Exceeded')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_call_rpc_error_awaited(self):
        """Timed-out RPC; the interceptor awaits the call itself."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                await call
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            # Timeout is half of the sleep value used by UnaryCallWithSleep.
            call = multicallable(
                messages_pb2.SimpleRequest(),
                timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)

            with self.assertRaises(aio.AioRpcError):
                await call

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertEqual(await call.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)
            self.assertEqual(await call.details(), 'Deadline Exceeded')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_cancel_before_rpc(self):
        """Cancel while the interceptor is blocked before invoking the RPC."""

        interceptor_reached = asyncio.Event()
        # Never resolved in this test: keeps the interceptor suspended until
        # the call is cancelled.
        wait_for_ever = self.loop.create_future()

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                interceptor_reached.set()
                await wait_for_ever

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            self.assertFalse(call.cancelled())
            self.assertFalse(call.done())

            # Cancel only once the interceptor is known to be running.
            await interceptor_reached.wait()
            self.assertTrue(call.cancel())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_after_rpc(self):
        """Cancel while the interceptor is blocked after the RPC completed."""

        interceptor_reached = asyncio.Event()
        # Never resolved in this test: keeps the interceptor suspended until
        # the call is cancelled.
        wait_for_ever = self.loop.create_future()

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                await call
                interceptor_reached.set()
                await wait_for_ever

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            self.assertFalse(call.cancelled())
            self.assertFalse(call.done())

            # Cancel only once the interceptor has awaited the inner call.
            await interceptor_reached.wait()
            self.assertTrue(call.cancel())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_inside_interceptor_after_rpc_awaiting(self):
        """The interceptor cancels the inner call and then awaits it."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                call.cancel()
                await call
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self):
        """The interceptor cancels the inner call and returns it un-awaited."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                call.cancel()
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            # Unlike the awaiting variant, the expected initial metadata here
            # is an empty tuple rather than None.
            self.assertEqual(await call.initial_metadata(), ())
            self.assertIsNone(await call.trailing_metadata())

    async def test_initial_metadata_modification(self):
        """Metadata appended by an interceptor is echoed back on the call."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # Tolerate call details that carry no metadata at all, so the
                # concatenation cannot fail on None.
                original_metadata = client_call_details.metadata or ()
                new_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=client_call_details.timeout,
                    metadata=original_metadata + _INITIAL_METADATA_TO_INJECT,
                    credentials=client_call_details.credentials,
                )
                return await continuation(new_details, request)

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()
                                                     ]) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            # Expected to see the echoed initial metadata
            self.assertTrue(
                _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await
                                       call.initial_metadata()))

            # Expected to see the echoed trailing metadata
            self.assertTrue(
                _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await
                                       call.trailing_metadata()))

            self.assertEqual(await call.code(), grpc.StatusCode.OK)
575
576
if __name__ == '__main__':
    # Set up default logging before handing control to the unittest runner.
    logging.basicConfig()
    unittest.main(verbosity=2)