Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests_aio / interop / methods.py
1 # Copyright 2019 The gRPC Authors
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Implementations of interoperability test methods."""
15
16 import argparse
17 import asyncio
18 import collections
19 import datetime
20 import enum
21 import inspect
22 import json
23 import os
24 import threading
25 import time
26 from typing import Any, Optional, Union
27
28 import grpc
29 from google import auth as google_auth
30 from google.auth import environment_vars as google_auth_environment_vars
31 from google.auth.transport import grpc as google_auth_transport_grpc
32 from google.auth.transport import requests as google_auth_transport_requests
33 from grpc.experimental import aio
34
35 from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
36
# Metadata keys sent by _custom_metadata and expected back on the call: the
# first is checked in the initial metadata, the second in the trailing
# metadata. The trailing key ends in "-bin" because its value is raw bytes.
_INITIAL_METADATA_KEY: str = 'x-grpc-test-echo-initial'
_TRAILING_METADATA_KEY: str = 'x-grpc-test-echo-trailing-bin'
39
40
async def _expect_status_code(call: aio.Call,
                              expected_code: grpc.StatusCode) -> None:
    """Await the call's status code and require it to equal ``expected_code``.

    Args:
      call: The RPC whose status code is checked.
      expected_code: The status code the call must report.

    Raises:
      ValueError: If the reported code differs from ``expected_code``.
    """
    code = await call.code()
    if code != expected_code:
        # Report the value already fetched instead of awaiting code() again.
        raise ValueError('expected code %s, got %s' % (expected_code, code))
47
48
async def _expect_status_details(call: aio.Call, expected_details: str) -> None:
    """Await the call's status details and require them to equal ``expected_details``.

    Args:
      call: The RPC whose status details are checked.
      expected_details: The exact details string the call must report.

    Raises:
      ValueError: If the reported details differ from ``expected_details``.
    """
    details = await call.details()
    if details != expected_details:
        # Report the value already fetched instead of awaiting details() again.
        raise ValueError('expected message %s, got %s' %
                         (expected_details, details))
54
55
async def _validate_status_code_and_details(
        call: aio.Call, expected_code: grpc.StatusCode,
        expected_details: str) -> None:
    """Check both the status code and the status details of ``call``.

    The code is checked first, so a wrong code is what gets reported even
    when the details are wrong as well.

    Raises:
      ValueError: On the first mismatch found.
    """
    await _expect_status_code(call, expected_code)
    await _expect_status_details(call, expected_details)
61
62
def _validate_payload_type_and_length(
        response: Union[messages_pb2.SimpleResponse,
                        messages_pb2.StreamingOutputCallResponse],
        expected_type: Any, expected_length: int) -> None:
    """Check a response payload's type and body length.

    Args:
      response: A response message carrying a ``payload`` field.
      expected_type: The payload type value the server should have used.
      expected_length: The exact number of bytes expected in the payload body.

    Raises:
      ValueError: If the payload type or body length does not match.
    """
    actual_type = response.payload.type
    # Payload types are protobuf enum values, which are plain ints in Python.
    # Compare them by value: an identity check only works by accident, for
    # small cached ints.
    if actual_type != expected_type:
        # Report the actual value, not its Python type (which is always int).
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, actual_type))
    actual_length = len(response.payload.body)
    if actual_length != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, actual_length))
73
74
async def _large_unary_common_behavior(
    stub: test_pb2_grpc.TestServiceStub, fill_username: bool,
    fill_oauth_scope: bool, call_credentials: Optional[grpc.CallCredentials]
) -> messages_pb2.SimpleResponse:
    """Issue one large UnaryCall and validate the response payload.

    Sends a 271828-byte zero-filled payload and asks for a 314159-byte
    COMPRESSABLE response.

    Args:
      stub: Client stub for the test service.
      fill_username: Forwarded as the request's ``fill_username`` flag.
      fill_oauth_scope: Forwarded as the request's ``fill_oauth_scope`` flag.
      call_credentials: Optional per-call credentials for the RPC.

    Returns:
      The validated response, so callers can inspect fields such as
      ``username`` and ``oauth_scope``.
    """
    request_body_size = 271828
    response_body_size = 314159
    large_request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_body_size,
        payload=messages_pb2.Payload(body=bytes(request_body_size)),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope,
    )
    reply = await stub.UnaryCall(large_request, credentials=call_credentials)
    _validate_payload_type_and_length(reply, messages_pb2.COMPRESSABLE,
                                      response_body_size)
    return reply
89
90
async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Send an Empty request to EmptyCall and require an Empty reply.

    Raises:
      TypeError: If the reply is not an ``empty_pb2.Empty`` instance.
    """
    reply = await stub.EmptyCall(empty_pb2.Empty())
    if isinstance(reply, empty_pb2.Empty):
        return
    raise TypeError('response is of type "%s", not empty_pb2.Empty!' %
                    type(reply))
96
97
async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Run the large-unary check without auth fields or call credentials."""
    await _large_unary_common_behavior(stub,
                                       fill_username=False,
                                       fill_oauth_scope=False,
                                       call_credentials=None)
100
101
async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Stream four zero-filled payloads and check the aggregated size.

    Raises:
      ValueError: If the server's ``aggregated_payload_size`` is not the sum
        of the streamed body sizes.
    """
    chunk_sizes = (27182, 8, 1828, 45904)

    async def outgoing():
        for chunk_size in chunk_sizes:
            yield messages_pb2.StreamingInputCallRequest(
                payload=messages_pb2.Payload(body=bytes(chunk_size)))

    reply = await stub.StreamingInputCall(outgoing())
    total = reply.aggregated_payload_size
    if total == sum(chunk_sizes):
        return
    raise ValueError('incorrect size %d!' % total)
119
120
async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Request four COMPRESSABLE responses and validate each one in order.

    Exactly ``len(expected_sizes)`` responses are read; each must have the
    COMPRESSABLE type and the matching body size.
    """
    expected_sizes = (31415, 9, 2653, 58979)

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=expected)
            for expected in expected_sizes),
    )
    call = stub.StreamingOutputCall(request)
    for expected in expected_sizes:
        _validate_payload_type_and_length(await call.read(),
                                          messages_pb2.COMPRESSABLE, expected)
142
143
async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Alternate one write and one read on a FullDuplexCall, four times.

    Each request carries a zero-filled payload and asks for a single
    COMPRESSABLE response of a given size. After the exchange, the client
    half-closes and the call must end with status OK and empty details.
    """
    response_sizes = (31415, 9, 2653, 58979)
    request_sizes = (27182, 8, 1828, 45904)

    call = stub.FullDuplexCall()
    for wanted_size, body_size in zip(response_sizes, request_sizes):
        await call.write(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(
                    messages_pb2.ResponseParameters(size=wanted_size),),
                payload=messages_pb2.Payload(body=bytes(body_size))))
        reply = await call.read()
        _validate_payload_type_and_length(reply, messages_pb2.COMPRESSABLE,
                                          wanted_size)
    await call.done_writing()
    await _validate_status_code_and_details(call, grpc.StatusCode.OK, '')
173
174
async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub):
    """Cancel a client-streaming call before sending any request.

    Raises:
      ValueError: If the call does not report itself cancelled, or its
        status code is not CANCELLED.
    """
    call = stub.StreamingInputCall()
    call.cancel()

    if not call.cancelled():
        raise ValueError('expected cancelled method to return True')
    if await call.code() is not grpc.StatusCode.CANCELLED:
        raise ValueError('expected status code CANCELLED')
183
184
async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub):
    """Cancel a full-duplex call right after its first response arrives.

    Raises:
      ValueError: If reading after ``cancel()`` does not raise
        ``asyncio.CancelledError``, or the call's status code is not
        CANCELLED.
    """
    # Only one request/response exchange happens before cancelling, so only
    # the first pair of the usual ping-pong sizes is needed.
    response_size = 31415
    payload_size = 27182

    call = stub.FullDuplexCall()
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(
            size=response_size),),
        payload=messages_pb2.Payload(body=b'\x00' * payload_size))
    await call.write(request)
    await call.read()

    call.cancel()

    try:
        await call.read()
    except asyncio.CancelledError:
        pass
    else:
        raise ValueError('expected call to be cancelled')

    # Use an explicit check rather than ``assert``: under ``python -O`` an
    # assert statement, including the awaited code() inside it, is removed,
    # so the status would never be verified.
    code = await call.code()
    if code is not grpc.StatusCode.CANCELLED:
        raise ValueError('expected status code CANCELLED, got %s' % (code,))
220
221
async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub):
    """Expect DEADLINE_EXCEEDED when the server is asked to respond late.

    The call gets a 1 s deadline, and the request sets ``interval_us`` to
    twice that deadline, so the response should not arrive in time.

    Raises:
      ValueError: If the read completes without an error.
      aio.AioRpcError: Re-raised if the RPC fails with any other code.
    """
    deadline = datetime.timedelta(seconds=1)
    server_delay_us = int(deadline.total_seconds() * 2 * 10**6)

    call = stub.FullDuplexCall(timeout=deadline.total_seconds())
    await call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            payload=messages_pb2.Payload(body=bytes(27182)),
            response_parameters=(messages_pb2.ResponseParameters(
                interval_us=server_delay_us),)))
    await call.done_writing()

    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        if rpc_error.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
            return
        raise
    raise ValueError('expected call to exceed deadline')
242
243
async def _empty_stream(stub: test_pb2_grpc.TestServiceStub):
    """Open a full-duplex call, half-close at once, and expect no responses.

    Raises:
      ValueError: If the first read returns anything other than ``aio.EOF``.
    """
    call = stub.FullDuplexCall()
    await call.done_writing()
    # Read outside any ``assert``: under ``python -O`` an assert statement,
    # including the awaited read inside it, is removed, so the check would
    # silently never run. EOF is a sentinel, so compare by identity.
    response = await call.read()
    if response is not aio.EOF:
        raise ValueError('expected EOF on empty stream, got %s' % (response,))
248
249
async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub):
    """Ask the server to echo a status, on both a unary and a duplex call.

    Each request embeds an EchoStatus with code UNKNOWN and a fixed message.
    Both calls must finish with exactly that code and details.
    """
    details = 'test status message'
    status = grpc.StatusCode.UNKNOWN  # code = 2

    def make_echo_status():
        return messages_pb2.EchoStatus(code=status.value[0], message=details)

    # UnaryCall variant.
    unary_call = stub.UnaryCall(
        messages_pb2.SimpleRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_size=1,
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=make_echo_status()))
    await _validate_status_code_and_details(unary_call, status, details)

    # FullDuplexCall variant: one request, half-close, then a single read
    # that may surface the echoed status as an AioRpcError.
    duplex_call = stub.FullDuplexCall()
    await duplex_call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=make_echo_status()))
    await duplex_call.done_writing()
    try:
        await duplex_call.read()
    except aio.AioRpcError as rpc_error:
        assert rpc_error.code() == status
    await _validate_status_code_and_details(duplex_call, status, details)
279
280
async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub):
    """Call UnimplementedCall on the test service and expect UNIMPLEMENTED."""
    await _expect_status_code(stub.UnimplementedCall(empty_pb2.Empty()),
                              grpc.StatusCode.UNIMPLEMENTED)
284
285
async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub):
    """Call UnimplementedCall on the unimplemented service; expect UNIMPLEMENTED."""
    await _expect_status_code(stub.UnimplementedCall(empty_pb2.Empty()),
                              grpc.StatusCode.UNIMPLEMENTED)
289
290
async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub):
    """Send echo metadata on unary and duplex calls and check it comes back.

    The initial-metadata value must come back in the call's initial
    metadata, and the binary value in its trailing metadata.

    Raises:
      ValueError: If either echoed value differs from what was sent.
    """
    expected_initial = "test_initial_metadata_value"
    expected_trailing = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    outgoing_metadata = aio.Metadata(
        (_INITIAL_METADATA_KEY, expected_initial),
        (_TRAILING_METADATA_KEY, expected_trailing),
    )

    async def check_echoed(call):
        received_initial = (
            await call.initial_metadata())[_INITIAL_METADATA_KEY]
        if received_initial != expected_initial:
            raise ValueError('expected initial metadata %s, got %s' %
                             (expected_initial, received_initial))

        received_trailing = (
            await call.trailing_metadata())[_TRAILING_METADATA_KEY]
        if received_trailing != expected_trailing:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (expected_trailing, received_trailing))

    # UnaryCall variant.
    unary_call = stub.UnaryCall(
        messages_pb2.SimpleRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_size=1,
            payload=messages_pb2.Payload(body=b'\x00')),
        metadata=outgoing_metadata)
    await check_echoed(unary_call)

    # FullDuplexCall variant: one round trip, then half-close before checking.
    duplex_call = stub.FullDuplexCall(metadata=outgoing_metadata)
    await duplex_call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),)))
    await duplex_call.read()
    await duplex_call.done_writing()
    await check_echoed(duplex_call)
329
330
async def _compute_engine_creds(stub: test_pb2_grpc.TestServiceStub,
                                args: argparse.Namespace):
    """Check that the server sees the expected default service account.

    Runs the large-unary check with ``fill_username`` and
    ``fill_oauth_scope`` set. The reported username must equal
    ``args.default_service_account``.

    Raises:
      ValueError: On a username mismatch.
    """
    response = await _large_unary_common_behavior(stub, True, True, None)
    expected_user = args.default_service_account
    if response.username == expected_user:
        return
    raise ValueError('expected username %s, got %s' %
                     (expected_user, response.username))
337
338
async def _oauth2_auth_token(stub: test_pb2_grpc.TestServiceStub,
                             args: argparse.Namespace):
    """Check the username and OAuth scope the server reports.

    The expected username is the ``client_email`` of the JSON key file named
    by the ``google_auth_environment_vars.CREDENTIALS`` environment variable.
    The reported ``oauth_scope`` must occur as a substring of
    ``args.oauth_scope``.

    Raises:
      KeyError: If the environment variable or ``client_email`` is missing.
      ValueError: On a username or scope mismatch.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    # JSON is UTF-8, so don't depend on the platform's locale encoding.
    with open(json_key_filename, 'r', encoding='utf-8') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = await _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))
351
352
async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub):
    """Check the server reports the key file's ``client_email`` as username.

    Runs the large-unary check with ``fill_username`` set and
    ``fill_oauth_scope`` unset.

    Raises:
      KeyError: If the credentials environment variable or ``client_email``
        is missing.
      ValueError: On a username mismatch.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    with open(json_key_filename, 'r', encoding='utf-8') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = await _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
360
361
async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub,
                         args: argparse.Namespace):
    """Authenticate a large unary call with per-RPC Google credentials.

    Credentials come from ``google.auth.default`` scoped to
    ``args.oauth_scope``. They are wrapped as gRPC metadata call credentials
    and attached to the single RPC. The server must then report the key
    file's ``client_email`` as the username.

    Raises:
      KeyError: If the credentials environment variable or ``client_email``
        is missing.
      ValueError: On a username mismatch.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    with open(json_key_filename, 'r', encoding='utf-8') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = await _large_unary_common_behavior(stub, True, False,
                                                  call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
377
378
async def _special_status_message(stub: test_pb2_grpc.TestServiceStub):
    """Echo a status message that contains whitespace and non-ASCII text.

    The details mix tabs and CR/LF with one BMP and one non-BMP character.
    The UnaryCall must report them back unchanged, with code UNKNOWN.
    """
    status = grpc.StatusCode.UNKNOWN  # code = 2
    details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode(
        'utf-8')

    call = stub.UnaryCall(
        messages_pb2.SimpleRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_size=1,
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                    message=details)))
    await _validate_status_code_and_details(call, status, details)
393
394
@enum.unique
class TestCase(enum.Enum):
    """Interop test cases implemented by this module.

    Each member's value is the case's string identifier.
    """
    # Basic RPC shapes.
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    # Cancellation, deadlines and stream edge cases.
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    EMPTY_STREAM = 'empty_stream'
    # Status codes, unimplemented endpoints and metadata.
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = 'custom_metadata'
    # Authentication.
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    # Status message encoding.
    SPECIAL_STATUS_MESSAGE = 'special_status_message'
415
416
# Dispatch table from each TestCase to the coroutine function implementing it.
# test_interoperability() counts each function's parameters to decide whether
# to pass the argparse args in addition to the stub.
_TEST_CASE_IMPLEMENTATION_MAPPING = {
    # Basic RPC shapes.
    TestCase.EMPTY_UNARY: _empty_unary,
    TestCase.LARGE_UNARY: _large_unary,
    TestCase.SERVER_STREAMING: _server_streaming,
    TestCase.CLIENT_STREAMING: _client_streaming,
    TestCase.PING_PONG: _ping_pong,
    # Cancellation, deadlines and stream edge cases.
    TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
    TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
    TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
    TestCase.EMPTY_STREAM: _empty_stream,
    # Status codes, unimplemented endpoints and metadata.
    TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
    TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
    TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
    TestCase.CUSTOM_METADATA: _custom_metadata,
    # Authentication.
    TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
    TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
    TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
    TestCase.PER_RPC_CREDS: _per_rpc_creds,
    # Status message encoding.
    TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
}
437
438
async def test_interoperability(
        case: TestCase,
        stub: test_pb2_grpc.TestServiceStub,
        args: Optional[argparse.Namespace] = None) -> None:
    """Run the implementation registered for ``case``.

    Implementations take either ``(stub)`` or ``(stub, args)``. Which form
    applies is decided by counting the function's parameters.

    Raises:
      NotImplementedError: If ``case`` has no registered implementation.
      ValueError: If the implementation needs ``args`` but none were given,
        or it declares an unsupported number of parameters.
    """
    method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
    if method is None:
        raise NotImplementedError(f'Test case "{case}" not implemented!')

    arity = len(inspect.signature(method).parameters)
    if arity == 1:
        await method(stub)
    elif arity == 2:
        if args is None:
            raise ValueError(f'Failed to run case [{case}]: args is None')
        await method(stub, args)
    else:
        raise ValueError(f'Invalid number of parameters [{arity}]')