Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / _cython / cygrpc_test.py
1 # Copyright 2015 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import time
16 import threading
17 import unittest
18 import platform
19
20 from grpc._cython import cygrpc
21 from tests.unit._cython import test_utilities
22 from tests.unit import test_common
23 from tests.unit import resources
24
# Flags argument handed to every cygrpc operation built in this module: no
# flags set.
_EMPTY_FLAGS = 0

# SSL target-name override used by the secure client channel (presumably the
# host name the test certificate in `resources` is issued for — confirm).
_SSL_HOST_OVERRIDE = b'foo.test.google.fr'

# Metadata pair that _metadata_plugin hands back as call-credentials metadata.
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
29
30
def _metadata_plugin(context, callback):
    """Metadata plugin that always supplies one fixed call-credentials pair.

    ``context`` is ignored. ``callback`` is invoked immediately with a
    one-element metadata tuple, an OK status code and empty details.
    """
    fixed_metadata = ((_CALL_CREDENTIALS_METADATA_KEY,
                       _CALL_CREDENTIALS_METADATA_VALUE),)
    callback(fixed_metadata, cygrpc.StatusCode.ok, b'')
36
37
38 class TypeSmokeTest(unittest.TestCase):
39
40     def testCompletionQueueUpDown(self):
41         completion_queue = cygrpc.CompletionQueue()
42         del completion_queue
43
44     def testServerUpDown(self):
45         server = cygrpc.Server(set([
46             (
47                 b'grpc.so_reuseport',
48                 0,
49             ),
50         ]))
51         del server
52
53     def testChannelUpDown(self):
54         channel = cygrpc.Channel(b'[::]:0', None, None)
55         channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')
56
57     def test_metadata_plugin_call_credentials_up_down(self):
58         cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
59                                              b'test plugin name!')
60
61     def testServerStartNoExplicitShutdown(self):
62         server = cygrpc.Server([
63             (
64                 b'grpc.so_reuseport',
65                 0,
66             ),
67         ])
68         completion_queue = cygrpc.CompletionQueue()
69         server.register_completion_queue(completion_queue)
70         port = server.add_http2_port(b'[::]:0')
71         self.assertIsInstance(port, int)
72         server.start()
73         del server
74
75     def testServerStartShutdown(self):
76         completion_queue = cygrpc.CompletionQueue()
77         server = cygrpc.Server([
78             (
79                 b'grpc.so_reuseport',
80                 0,
81             ),
82         ])
83         server.add_http2_port(b'[::]:0')
84         server.register_completion_queue(completion_queue)
85         server.start()
86         shutdown_tag = object()
87         server.shutdown(completion_queue, shutdown_tag)
88         event = completion_queue.poll()
89         self.assertEqual(cygrpc.CompletionType.operation_complete,
90                          event.completion_type)
91         self.assertIs(shutdown_tag, event.tag)
92         del server
93         del completion_queue
94
95
class ServerClientMixin(object):
    """Server/client round-trip tests shared by the concrete test cases.

    Subclasses mix this into a ``unittest.TestCase`` and call ``setUpMixin``
    / ``tearDownMixin`` from ``setUp`` / ``tearDown``.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on an OS-chosen port and open a channel to it.

        Args:
          server_credentials: Passed to ``add_http2_port``; a falsy value
            binds an insecure port.
          client_credentials: Channel credentials; a falsy value opens an
            insecure channel.
          host_override: Used as the SSL target-name override when
            ``client_credentials`` is set, and as the host the server is
            expected to observe when truthy.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(b'[::]:0',
                                                   server_credentials)
        else:
            self.port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        if client_credentials:
            client_channel_arguments = ((
                cygrpc.ChannelArgKey.ssl_target_name_override,
                host_override,
            ),)
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(),
                client_channel_arguments, client_credentials)
        else:
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(), set(), None)
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b'hostess'
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and drop every object setUpMixin made."""
        self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    @staticmethod
    def _operation_error(description, error):
        """Return the exception reported when a queued batch fails.

        The message embeds ``str(error)``. Python 3 exceptions have no
        ``.message`` attribute, so reading it would raise AttributeError and
        mask the real failure.
        """
        return Exception("Error in '{}': {}".format(description, error))

    def _perform_queue_operations(self, operations, call, queue, deadline,
                                  description):
        """Perform the operations with given call, queue, and deadline.

        Invocation errors are reported as an exception with `description`
        in the message. Performs the operations asynchronously, returning a
        ``test_utilities.SimpleFuture`` whose result is the completion event.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(cygrpc.CompletionType.operation_complete,
                                 event.completion_type)
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # On Python 3 the original error stays attached as
                # __context__, so its traceback is still reported.
                raise self._operation_error(description, error)
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        """Round-trip one request/response with metadata and status.

        Checks that client metadata, method, host and deadline reach the
        server, and that the server's initial metadata, response, trailing
        metadata, status code and details reach the client.
        """
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = 'key'
        CLIENT_METADATA_ASCII_VALUE = 'val'
        CLIENT_METADATA_BIN_KEY = 'key-bin'
        CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
        SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
        SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
        SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
        SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = 'our work is never over'
        REQUEST = b'in death a member of project mayhem has a name'
        RESPONSE = b'his name is robert paulson'
        METHOD = b'twinkies'

        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)

        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        client_call = self.client_channel.integrated_call(
            0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
            None, [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            client_initial_metadata, _EMPTY_FLAGS),
                        cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_call_tag,
                ),
            ])
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         request_event.completion_type)
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(client_initial_metadata,
                                             request_event.invocation_metadata))
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(abs(DEADLINE - request_event.call_details.deadline),
                        DEADLINE_TOLERANCE)

        server_call_tag = object()
        server_call = request_event.call
        server_initial_metadata = ((
            SERVER_INITIAL_METADATA_KEY,
            SERVER_INITIAL_METADATA_VALUE,
        ),)
        server_trailing_metadata = ((
            SERVER_TRAILING_METADATA_KEY,
            SERVER_TRAILING_METADATA_VALUE,
        ),)
        server_start_batch_result = server_call.start_server_batch([
            cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                server_trailing_metadata, SERVER_STATUS_CODE,
                SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
        ], server_call_tag)
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        # The client batch above contained six operations.
        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if client_result.type(
            ) == cygrpc.OperationType.receive_initial_metadata:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata()))
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif client_result.type(
            ) == cygrpc.OperationType.receive_status_on_client:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata()))
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.send_close_from_client,
                cygrpc.OperationType.receive_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.receive_status_on_client
            ]), found_client_op_types)

        # The server batch above contained five operations.
        self.assertEqual(5, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif server_result.type(
            ) == cygrpc.OperationType.receive_close_on_server:
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.receive_close_on_server,
                cygrpc.OperationType.send_status_from_server
            ]), found_server_op_types)

        del client_call
        del server_call

    def test_6522(self):
        """Drive one segregated client call through many message rounds.

        Named after gRPC issue 6522 (presumably a regression test for it).
        The call runs in three stages, and each stage waits for both the
        client and the server side to finish:
          - a prologue, exchanging initial metadata;
          - ten rounds of empty messages in each direction;
          - an epilogue: half-close, then status.
        """
        DEADLINE = time.time() + 5
        METHOD = b'twinkies'

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)
        # Fail fast here rather than later via a confusing poll result.
        self.assertEqual(cygrpc.CallError.ok, request_call_result)
        client_call = self.client_channel.segregated_call(
            0, METHOD, self.host_argument, DEADLINE, None, None, [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            empty_metadata, _EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                    ],
                    object(),
                ),
                (
                    [
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    object(),
                ),
            ])

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_queue_operations(operations, server_call,
                                                  self.server_completion_queue,
                                                  DEADLINE, description)

        server_event_future = perform_server_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
        ], "Server prologue")

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            # NOTE(review): the string argument is presumably the batch tag
            # of SegregatedCall.operate — confirm.
            client_call.operate([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Client message")
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event)
            server_event_future = perform_server_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Server receive")

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate([
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
        ], "Client epilogue")
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        # Collected in a tuple, so the events need not be hashable.
        client_events_future = test_utilities.SimpleFuture(lambda: (
            client_call.next_event(),
            client_call.next_event(),
        ))

        server_event_future = perform_server_operations([
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
        ], "Server epilogue")

        client_events_future.result()  # force completion
        server_event_future.result()
391         client_events_future.result()  # force completion
392         server_event_future.result()
393
394
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests without credentials on either side."""

    def setUp(self):
        self.setUpMixin(server_credentials=None,
                        client_credentials=None,
                        host_override=None)

    def tearDown(self):
        self.tearDownMixin()
400     def tearDown(self):
401         self.tearDownMixin()
402
403
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests over SSL on both server and client."""

    def setUp(self):
        key_cert_pair = cygrpc.SslPemKeyCertPair(
            resources.private_key(), resources.certificate_chain())
        # No root certs on the server side. The trailing False presumably
        # means client certificates are not required — confirm.
        server_creds = cygrpc.server_credentials_ssl(None, [key_cert_pair],
                                                     False)
        channel_creds = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None)
        self.setUpMixin(server_creds, channel_creds, _SSL_HOST_OVERRIDE)

    def tearDown(self):
        self.tearDownMixin()
417     def tearDown(self):
418         self.tearDownMixin()
419
420
421 if __name__ == '__main__':
422     unittest.main(verbosity=2)