1 # Copyright 2015 gRPC authors.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
20 from grpc._cython import cygrpc
21 from tests.unit._cython import test_utilities
22 from tests.unit import test_common
23 from tests.unit import resources
# Host name presumably handed to the secure test case as the SSL target-name
# override (see `ssl_target_name_override` in the mixin) — confirm at call site.
_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
# Key/value of the single metadata pair that `_metadata_plugin` hands back to
# its callback.
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
def _metadata_plugin(context, callback):
    """Metadata plugin used by the call-credentials smoke test.

    `context` is ignored. `callback` is invoked synchronously with a tuple
    holding one (key, value) metadata pair, an OK status code, and empty
    (bytes) details.
    """
    call_credentials_metadata = (
        (_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE),
    )
    callback(call_credentials_metadata, cygrpc.StatusCode.ok, b'')
38 class TypeSmokeTest(unittest.TestCase):
40 def testCompletionQueueUpDown(self):
41 completion_queue = cygrpc.CompletionQueue()
44 def testServerUpDown(self):
45 server = cygrpc.Server(set([
53 def testChannelUpDown(self):
54 channel = cygrpc.Channel(b'[::]:0', None, None)
55 channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')
57 def test_metadata_plugin_call_credentials_up_down(self):
58 cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
61 def testServerStartNoExplicitShutdown(self):
62 server = cygrpc.Server([
68 completion_queue = cygrpc.CompletionQueue()
69 server.register_completion_queue(completion_queue)
70 port = server.add_http2_port(b'[::]:0')
71 self.assertIsInstance(port, int)
75 def testServerStartShutdown(self):
76 completion_queue = cygrpc.CompletionQueue()
77 server = cygrpc.Server([
83 server.add_http2_port(b'[::]:0')
84 server.register_completion_queue(completion_queue)
86 shutdown_tag = object()
87 server.shutdown(completion_queue, shutdown_tag)
88 event = completion_queue.poll()
89 self.assertEqual(cygrpc.CompletionType.operation_complete,
90 event.completion_type)
91 self.assertIs(shutdown_tag, event.tag)
96 class ServerClientMixin(object):
98 def setUpMixin(self, server_credentials, client_credentials, host_override):
99 self.server_completion_queue = cygrpc.CompletionQueue()
100 self.server = cygrpc.Server([
102 b'grpc.so_reuseport',
106 self.server.register_completion_queue(self.server_completion_queue)
107 if server_credentials:
108 self.port = self.server.add_http2_port(b'[::]:0',
111 self.port = self.server.add_http2_port(b'[::]:0')
113 self.client_completion_queue = cygrpc.CompletionQueue()
114 if client_credentials:
115 client_channel_arguments = ((
116 cygrpc.ChannelArgKey.ssl_target_name_override,
119 self.client_channel = cygrpc.Channel(
120 'localhost:{}'.format(self.port).encode(),
121 client_channel_arguments, client_credentials)
123 self.client_channel = cygrpc.Channel(
124 'localhost:{}'.format(self.port).encode(), set(), None)
126 self.host_argument = None # default host
127 self.expected_host = host_override
129 # arbitrary host name necessitating no further identification
130 self.host_argument = b'hostess'
131 self.expected_host = self.host_argument
133 def tearDownMixin(self):
134 self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
135 del self.client_channel
137 del self.client_completion_queue
138 del self.server_completion_queue
140 def _perform_queue_operations(self, operations, call, queue, deadline,
142 """Perform the operations with given call, queue, and deadline.
144 Invocation errors are reported with as an exception with `description`
145 in the message. Performs the operations asynchronously, returning a
152 call_result = call.start_client_batch(operations, tag)
153 self.assertEqual(cygrpc.CallError.ok, call_result)
154 event = queue.poll(deadline=deadline)
155 self.assertEqual(cygrpc.CompletionType.operation_complete,
156 event.completion_type)
157 self.assertTrue(event.success)
158 self.assertIs(tag, event.tag)
159 except Exception as error:
160 raise Exception("Error in '{}': {}".format(
161 description, error.message))
164 return test_utilities.SimpleFuture(performer)
167 DEADLINE = time.time() + 5
168 DEADLINE_TOLERANCE = 0.25
169 CLIENT_METADATA_ASCII_KEY = 'key'
170 CLIENT_METADATA_ASCII_VALUE = 'val'
171 CLIENT_METADATA_BIN_KEY = 'key-bin'
172 CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
173 SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
174 SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
175 SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
176 SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
177 SERVER_STATUS_CODE = cygrpc.StatusCode.ok
178 SERVER_STATUS_DETAILS = 'our work is never over'
179 REQUEST = b'in death a member of project mayhem has a name'
180 RESPONSE = b'his name is robert paulson'
183 server_request_tag = object()
184 request_call_result = self.server.request_call(
185 self.server_completion_queue, self.server_completion_queue,
188 self.assertEqual(cygrpc.CallError.ok, request_call_result)
190 client_call_tag = object()
191 client_initial_metadata = (
193 CLIENT_METADATA_ASCII_KEY,
194 CLIENT_METADATA_ASCII_VALUE,
197 CLIENT_METADATA_BIN_KEY,
198 CLIENT_METADATA_BIN_VALUE,
201 client_call = self.client_channel.integrated_call(
202 0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
206 cygrpc.SendInitialMetadataOperation(
207 client_initial_metadata, _EMPTY_FLAGS),
208 cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
209 cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
210 cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
211 cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
212 cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
217 client_event_future = test_utilities.SimpleFuture(
218 self.client_channel.next_call_event)
220 request_event = self.server_completion_queue.poll(deadline=DEADLINE)
221 self.assertEqual(cygrpc.CompletionType.operation_complete,
222 request_event.completion_type)
223 self.assertIsInstance(request_event.call, cygrpc.Call)
224 self.assertIs(server_request_tag, request_event.tag)
226 test_common.metadata_transmitted(client_initial_metadata,
227 request_event.invocation_metadata))
228 self.assertEqual(METHOD, request_event.call_details.method)
229 self.assertEqual(self.expected_host, request_event.call_details.host)
230 self.assertLess(abs(DEADLINE - request_event.call_details.deadline),
233 server_call_tag = object()
234 server_call = request_event.call
235 server_initial_metadata = ((
236 SERVER_INITIAL_METADATA_KEY,
237 SERVER_INITIAL_METADATA_VALUE,
239 server_trailing_metadata = ((
240 SERVER_TRAILING_METADATA_KEY,
241 SERVER_TRAILING_METADATA_VALUE,
243 server_start_batch_result = server_call.start_server_batch([
244 cygrpc.SendInitialMetadataOperation(server_initial_metadata,
246 cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
247 cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
248 cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
249 cygrpc.SendStatusFromServerOperation(
250 server_trailing_metadata, SERVER_STATUS_CODE,
251 SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
253 self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
255 server_event = self.server_completion_queue.poll(deadline=DEADLINE)
256 client_event = client_event_future.result()
258 self.assertEqual(6, len(client_event.batch_operations))
259 found_client_op_types = set()
260 for client_result in client_event.batch_operations:
261 # we expect each op type to be unique
262 self.assertNotIn(client_result.type(), found_client_op_types)
263 found_client_op_types.add(client_result.type())
264 if client_result.type(
265 ) == cygrpc.OperationType.receive_initial_metadata:
267 test_common.metadata_transmitted(
268 server_initial_metadata,
269 client_result.initial_metadata()))
270 elif client_result.type() == cygrpc.OperationType.receive_message:
271 self.assertEqual(RESPONSE, client_result.message())
272 elif client_result.type(
273 ) == cygrpc.OperationType.receive_status_on_client:
275 test_common.metadata_transmitted(
276 server_trailing_metadata,
277 client_result.trailing_metadata()))
278 self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
279 self.assertEqual(SERVER_STATUS_CODE, client_result.code())
282 cygrpc.OperationType.send_initial_metadata,
283 cygrpc.OperationType.send_message,
284 cygrpc.OperationType.send_close_from_client,
285 cygrpc.OperationType.receive_initial_metadata,
286 cygrpc.OperationType.receive_message,
287 cygrpc.OperationType.receive_status_on_client
288 ]), found_client_op_types)
290 self.assertEqual(5, len(server_event.batch_operations))
291 found_server_op_types = set()
292 for server_result in server_event.batch_operations:
293 self.assertNotIn(server_result.type(), found_server_op_types)
294 found_server_op_types.add(server_result.type())
295 if server_result.type() == cygrpc.OperationType.receive_message:
296 self.assertEqual(REQUEST, server_result.message())
297 elif server_result.type(
298 ) == cygrpc.OperationType.receive_close_on_server:
299 self.assertFalse(server_result.cancelled())
302 cygrpc.OperationType.send_initial_metadata,
303 cygrpc.OperationType.receive_message,
304 cygrpc.OperationType.send_message,
305 cygrpc.OperationType.receive_close_on_server,
306 cygrpc.OperationType.send_status_from_server
307 ]), found_server_op_types)
313 DEADLINE = time.time() + 5
314 DEADLINE_TOLERANCE = 0.25
320 server_request_tag = object()
321 self.server.request_call(self.server_completion_queue,
322 self.server_completion_queue,
324 client_call = self.client_channel.segregated_call(
325 0, METHOD, self.host_argument, DEADLINE, None, None,
328 cygrpc.SendInitialMetadataOperation(empty_metadata,
330 cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
336 cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
341 client_initial_metadata_event_future = test_utilities.SimpleFuture(
342 client_call.next_event)
344 request_event = self.server_completion_queue.poll(deadline=DEADLINE)
345 server_call = request_event.call
347 def perform_server_operations(operations, description):
348 return self._perform_queue_operations(operations, server_call,
349 self.server_completion_queue,
350 DEADLINE, description)
352 server_event_future = perform_server_operations([
353 cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
354 ], "Server prologue")
356 client_initial_metadata_event_future.result() # force completion
357 server_event_future.result()
361 client_call.operate([
362 cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
363 cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
365 client_message_event_future = test_utilities.SimpleFuture(
366 client_call.next_event)
367 server_event_future = perform_server_operations([
368 cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
369 cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
372 client_message_event_future.result() # force completion
373 server_event_future.result()
376 client_call.operate([
377 cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
378 ], "Client epilogue")
379 # One for ReceiveStatusOnClient, one for SendCloseFromClient.
380 client_events_future = test_utilities.SimpleFuture(lambda: {
381 client_call.next_event(),
382 client_call.next_event(),
385 server_event_future = perform_server_operations([
386 cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
387 cygrpc.SendStatusFromServerOperation(
388 empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
389 ], "Server epilogue")
391 client_events_future.result() # force completion
392 server_event_future.result()
395 class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
398 self.setUpMixin(None, None, None)
404 class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
407 server_credentials = cygrpc.server_credentials_ssl(
409 cygrpc.SslPemKeyCertPair(resources.private_key(),
410 resources.certificate_chain())
412 client_credentials = cygrpc.SSLChannelCredentials(
413 resources.test_root_certificates(), None, None)
414 self.setUpMixin(server_credentials, client_credentials,
if __name__ == '__main__':
    # Run this module's test cases directly, with verbose (per-test) output.
    unittest.main(verbosity=2)