1 # Copyright 2019 The gRPC Authors.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
import asyncio
import logging
import unittest

import grpc
from grpc.experimental import aio

from tests_aio.unit._test_server import (start_test_server,
                                         _INITIAL_METADATA_KEY,
                                         _TRAILING_METADATA_KEY)
from tests_aio.unit import _constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
# Details string that the cancellation tests expect from `call.details()`
# after the call was cancelled locally with `call.cancel()`.
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'

# Metadata pairs appended to the outgoing metadata by the interceptor in
# `test_initial_metadata_modification`. The test expects the first pair to be
# echoed back in the initial metadata and the second in the trailing metadata
# (echo behaviour is provided by the test server — see `_test_server`).
_INITIAL_METADATA_TO_INJECT = (
    (_INITIAL_METADATA_KEY, 'extra info'),
    (_TRAILING_METADATA_KEY, b'\x13\x37'),
)
class TestUnaryUnaryClientInterceptor(AioTestBase):
    """Tests for ``aio.UnaryUnaryClientInterceptor`` installed on a channel.

    Each test talks to a fresh server started in ``setUp`` through
    ``start_test_server`` and stopped again in ``tearDown``.
    """

    async def setUp(self):
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        await self._server.stop(None)

    def test_invalid_interceptor(self):
        """Passing a non-interceptor object to the channel raises ValueError."""

        class InvalidInterceptor:
            """Just an invalid Interceptor"""

        with self.assertRaises(ValueError):
            aio.insecure_channel("", interceptors=[InvalidInterceptor()])

    async def test_executed_right_order(self):
        """Interceptors are invoked in the order they were given."""

        interceptors_executed = []

        class Interceptor(aio.UnaryUnaryClientInterceptor):
            """Interceptor used for testing if the interceptor is being called"""

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                interceptors_executed.append(self)
                call = await continuation(client_call_details, request)
                return call

        interceptors = [Interceptor() for _ in range(2)]

        async with aio.insecure_channel(self._server_target,
                                        interceptors=interceptors) as channel:
            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            # Check that all interceptors were executed, and were executed
            # in the right order.
            self.assertSequenceEqual(interceptors_executed, interceptors)

            self.assertIsInstance(response, messages_pb2.SimpleResponse)

    @unittest.expectedFailure
    # TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is
    # implemented in the client-side, this test must be implemented.
    def test_modify_metadata(self):
        raise NotImplementedError()

    @unittest.expectedFailure
    # TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is
    # implemented in the client-side, this test must be implemented.
    def test_modify_credentials(self):
        raise NotImplementedError()

    async def test_status_code_Ok(self):
        """An interceptor can observe StatusCode.OK of a successful RPC."""

        class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor):
            """Interceptor used for observing status code Ok returned by the RPC"""

            def __init__(self):
                self.status_code_Ok_observed = False

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                code = await call.code()
                if code == grpc.StatusCode.OK:
                    self.status_code_Ok_observed = True

                return call

        interceptor = StatusCodeOkInterceptor()

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[interceptor]) as channel:

            # when no error StatusCode.OK must be observed
            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            await multicallable(messages_pb2.SimpleRequest())

            self.assertTrue(interceptor.status_code_Ok_observed)

    async def test_add_timeout(self):
        """A timeout injected by an interceptor is applied to the RPC."""

        class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor):
            """Interceptor used for adding a timeout to the RPC"""

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # Replace the details, keeping everything but the timeout.
                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials)
                return await continuation(new_client_call_details, request)

        interceptor = TimeoutInterceptor()

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[interceptor]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(exception_context.exception.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
                             call.code())

    async def test_retry(self):
        """An interceptor may issue several RPCs and return the last one."""

        class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
            """Simulates a Retry Interceptor which ends up by making
            two RPC calls."""

            def __init__(self):
                # Every call issued by the interceptor, in issue order.
                self.calls = []

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):

                # First attempt: timeout of half UNARY_CALL_WITH_SLEEP_VALUE;
                # the test asserts this attempt ends with DEADLINE_EXCEEDED.
                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials)

                # Obtain the call outside the `try` so `call` is always bound
                # when it is recorded below; only the awaited outcome is an
                # expected failure.
                call = await continuation(new_client_call_details, request)
                try:
                    await call
                except grpc.RpcError:
                    # Expected: the test inspects this call's code afterwards.
                    pass

                self.calls.append(call)

                # Second attempt: no timeout, expected to finish OK.
                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=None,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials)

                call = await continuation(new_client_call_details, request)
                self.calls.append(call)
                return call

        interceptor = RetryInterceptor()

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[interceptor]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(messages_pb2.SimpleRequest())

            await call

            self.assertEqual(grpc.StatusCode.OK, await call.code())

            # Check that two calls were made, first one finishing with
            # a deadline and second one finishing ok.
            self.assertEqual(len(interceptor.calls), 2)
            self.assertEqual(await interceptor.calls[0].code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)
            self.assertEqual(await interceptor.calls[1].code(),
                             grpc.StatusCode.OK)

    async def test_rpcresponse(self):
        """A raw response returned by an interceptor is wrapped in a call."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):
            """Raw responses are seen as regular calls"""

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                # Awaiting must work even though the next interceptor returns
                # a bare response message instead of a call.
                await call
                return call

        class ResponseInterceptor(aio.UnaryUnaryClientInterceptor):
            """Return a raw response"""
            response = messages_pb2.SimpleResponse()

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # Short-circuit: never invokes `continuation`.
                return ResponseInterceptor.response

        interceptor, interceptor_response = Interceptor(), ResponseInterceptor()

        async with aio.insecure_channel(
                self._server_target,
                interceptors=[interceptor, interceptor_response]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            # Check that the response returned is the one returned by the
            # interceptor
            self.assertIs(response, ResponseInterceptor.response)

            # Check all of the UnaryUnaryCallResponse attributes
            self.assertTrue(call.done())
            self.assertFalse(call.cancel())
            self.assertFalse(call.cancelled())
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
            self.assertEqual(await call.details(), '')
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())
            self.assertIsNone(await call.debug_error_string())
class TestInterceptedUnaryUnaryCall(AioTestBase):
    """Tests for the call object returned when interceptors are installed.

    Each test talks to a fresh server started in ``setUp`` through
    ``start_test_server`` and stopped again in ``tearDown``.
    """

    async def setUp(self):
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        await self._server.stop(None)

    async def test_call_ok(self):
        """A pass-through interceptor yields a successful call."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
            self.assertEqual(await call.details(), '')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_call_ok_awaited(self):
        """A call already awaited inside the interceptor still succeeds."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                # Consume the result inside the interceptor before returning.
                await call
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
            self.assertEqual(await call.details(), '')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_call_rpc_error(self):
        """A deadline failure surfaces as AioRpcError through the interceptor."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(
                messages_pb2.SimpleRequest(),
                timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            # The raised error must carry the same status as the call.
            self.assertEqual(exception_context.exception.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertEqual(await call.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)
            self.assertEqual(await call.details(), 'Deadline Exceeded')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_call_rpc_error_awaited(self):
        """Same as test_call_rpc_error, but the interceptor awaits the call."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                # Raises inside the interceptor; the error propagates out.
                await call
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(
                messages_pb2.SimpleRequest(),
                timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            # The raised error must carry the same status as the call.
            self.assertEqual(exception_context.exception.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)

            self.assertTrue(call.done())
            self.assertFalse(call.cancelled())
            self.assertEqual(await call.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)
            self.assertEqual(await call.details(), 'Deadline Exceeded')
            self.assertEqual(await call.initial_metadata(), ())
            self.assertEqual(await call.trailing_metadata(), ())

    async def test_cancel_before_rpc(self):
        """Cancelling while the interceptor blocks before starting the RPC."""

        interceptor_reached = asyncio.Event()
        # Never resolved: keeps the interceptor suspended until cancellation.
        wait_for_ever = self.loop.create_future()

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                interceptor_reached.set()
                await wait_for_ever

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            self.assertFalse(call.cancelled())
            self.assertFalse(call.done())

            await interceptor_reached.wait()
            self.assertTrue(call.cancel())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_after_rpc(self):
        """Cancelling while the interceptor blocks after the RPC finished."""

        interceptor_reached = asyncio.Event()
        # Never resolved: keeps the interceptor suspended until cancellation.
        wait_for_ever = self.loop.create_future()

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                await call
                interceptor_reached.set()
                await wait_for_ever

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            self.assertFalse(call.cancelled())
            self.assertFalse(call.done())

            await interceptor_reached.wait()
            self.assertTrue(call.cancel())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_inside_interceptor_after_rpc_awaiting(self):
        """The interceptor cancels the call and then awaits it."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                call.cancel()
                await call
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self):
        """The interceptor cancels the call and returns it without awaiting."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                call = await continuation(client_call_details, request)
                call.cancel()
                return call

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.details(),
                             _LOCAL_CANCEL_DETAILS_EXPECTATION)
            # Unlike the awaiting variant, initial metadata is an empty tuple.
            self.assertEqual(await call.initial_metadata(), ())
            self.assertIsNone(await call.trailing_metadata())

    async def test_initial_metadata_modification(self):
        """Metadata appended by an interceptor reaches the server."""

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                new_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=client_call_details.timeout,
                    metadata=client_call_details.metadata +
                    _INITIAL_METADATA_TO_INJECT,
                    credentials=client_call_details.credentials,
                )
                return await continuation(new_details, request)

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[Interceptor()]) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            # Expected to see the echoed initial metadata
            self.assertTrue(
                _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await
                                       call.initial_metadata()))

            # Expected to see the echoed trailing metadata
            self.assertTrue(
                _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await
                                       call.trailing_metadata()))

            self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
    # Attach a default stderr handler to the root logger, then run every
    # test case defined in this module with per-test verbose output.
    logging.basicConfig()
    unittest.main(module='__main__', verbosity=2)