Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / _compression_test.py
1 # Copyright 2016 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Tests server and client side compression."""
15
16 import unittest
17
18 import contextlib
19 from concurrent import futures
20 import functools
21 import itertools
22 import logging
23 import os
24
25 import grpc
26 from grpc import _grpcio_metadata
27
28 from tests.unit import test_common
29 from tests.unit.framework.common import test_constants
30 from tests.unit import _tcp_proxy
31
# Method paths that _GenericHandler routes to echoing handlers.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

_HOST = 'localhost'

# Payload of 100 zero bytes: trivially compressible, so any active
# compression shows up clearly in the proxy's byte counts.
_REQUEST = b'\x00' * 100

# A byte-count ratio below -_COMPRESSION_RATIO_THRESHOLD counts as
# "compressed"; anything at or above it counts as "not compressed".
_COMPRESSION_RATIO_THRESHOLD = 0.05

# Values swept for every compression knob; None means "not configured".
_COMPRESSION_METHODS = (
    None,
    # Disabled for test tractability.
    # grpc.Compression.NoCompression,
    # grpc.Compression.Deflate,
    grpc.Compression.Gzip,
)

# Human-readable fragments used when building generated test names.
_COMPRESSION_NAMES = {
    None: 'Uncompressed',
    grpc.Compression.NoCompression: 'NoCompression',
    grpc.Compression.Deflate: 'DeflateCompression',
    grpc.Compression.Gzip: 'GzipCompression',
}

# Each key names a keyword argument of
# CompressionTest.assertConfigurationCompressed; the cartesian product of
# the values yields one generated test per combination.
_TEST_OPTIONS = {
    'client_streaming': (True, False),
    'server_streaming': (True, False),
    'channel_compression': _COMPRESSION_METHODS,
    'multicallable_compression': _COMPRESSION_METHODS,
    'server_compression': _COMPRESSION_METHODS,
    'server_call_compression': _COMPRESSION_METHODS,
}
66
67
68 def _make_handle_unary_unary(pre_response_callback):
69
70     def _handle_unary(request, servicer_context):
71         if pre_response_callback:
72             pre_response_callback(request, servicer_context)
73         return request
74
75     return _handle_unary
76
77
def _make_handle_unary_stream(pre_response_callback):
    """Create a unary-stream behavior that repeats its request.

    Args:
      pre_response_callback: Optional callable invoked as
        ``pre_response_callback(request, servicer_context)`` before the first
        response is yielded. Skipped when falsy (e.g. None).

    Returns:
      A generator function ``(request, servicer_context)`` that yields
      ``request`` exactly ``_STREAM_LENGTH`` times.
    """

    def _repeat(message, context):
        if pre_response_callback:
            pre_response_callback(message, context)
        remaining = _STREAM_LENGTH
        while remaining > 0:
            yield message
            remaining -= 1

    return _repeat
87
88
89 def _make_handle_stream_unary(pre_response_callback):
90
91     def _handle_stream_unary(request_iterator, servicer_context):
92         if pre_response_callback:
93             pre_response_callback(request_iterator, servicer_context)
94         response = None
95         for request in request_iterator:
96             if not response:
97                 response = request
98         return response
99
100     return _handle_stream_unary
101
102
103 def _make_handle_stream_stream(pre_response_callback):
104
105     def _handle_stream(request_iterator, servicer_context):
106         # TODO(issue:#6891) We should be able to remove this loop,
107         # and replace with return; yield
108         for request in request_iterator:
109             if pre_response_callback:
110                 pre_response_callback(request, servicer_context)
111             yield request
112
113     return _handle_stream
114
115
def set_call_compression(compression_method, request_or_iterator,
                         servicer_context):
    """Pre-response callback that picks the compression for the whole call.

    Intended to be bound with ``functools.partial`` on ``compression_method``
    so the remaining signature matches a pre-response callback.

    Args:
      compression_method: Value handed to
        ``servicer_context.set_compression``.
      request_or_iterator: Ignored; present only to fit the callback shape.
      servicer_context: The server-side context of the current RPC.
    """
    context_setter = servicer_context.set_compression
    context_setter(compression_method)
120
121
def disable_next_compression(request, servicer_context):
    """Pre-response callback that sends the next message uncompressed.

    Args:
      request: Ignored; present only to fit the callback shape.
      servicer_context: The server-side context of the current RPC.
    """
    disable = servicer_context.disable_next_message_compression
    disable()
125
126
def disable_first_compression(request, servicer_context):
    """Pre-response callback that leaves only message index 0 uncompressed.

    Args:
      request: ASCII-encoded decimal index (leading zeros allowed), as
        produced by the stream-stream client in this file.
      servicer_context: The server-side context of the current RPC.
    """
    message_index = int(request.decode('ascii'))
    if message_index != 0:
        return
    servicer_context.disable_next_message_compression()
130
131
class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler exposing exactly one echoing behavior.

    The two streaming flags select which of unary_unary, unary_stream,
    stream_unary or stream_stream is populated; the other three stay None.
    No (de)serializers are configured, so messages are passed through as-is.
    """

    def __init__(self, request_streaming, response_streaming,
                 pre_response_callback):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        for behavior_name in ('unary_unary', 'unary_stream', 'stream_unary',
                              'stream_stream'):
            setattr(self, behavior_name, None)

        if request_streaming:
            if response_streaming:
                self.stream_stream = _make_handle_stream_stream(
                    pre_response_callback)
            else:
                self.stream_unary = _make_handle_stream_unary(
                    pre_response_callback)
        elif response_streaming:
            self.unary_stream = _make_handle_unary_stream(
                pre_response_callback)
        else:
            self.unary_unary = _make_handle_unary_unary(
                pre_response_callback)
154
155
class _GenericHandler(grpc.GenericRpcHandler):
    """Serves the four /test/* methods with echoing _MethodHandlers.

    Every handler it creates shares the same pre-response callback (or None).
    """

    def __init__(self, pre_response_callback):
        self._pre_response_callback = pre_response_callback

    def service(self, handler_call_details):
        """Return a handler for a known test method, or None otherwise."""
        streaming_by_method = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }
        streaming = streaming_by_method.get(handler_call_details.method)
        if streaming is None:
            return None
        request_streaming, response_streaming = streaming
        return _MethodHandler(request_streaming, response_streaming,
                              self._pre_response_callback)
172
173
@contextlib.contextmanager
def _instrumented_client_server_pair(channel_kwargs, server_kwargs,
                                     server_handler):
    """Run a server and a client channel that reaches it through a TCP proxy.

    The client connects to the proxy, not directly to the server, so the
    proxy can report how many bytes each direction carried.

    Args:
      channel_kwargs: Extra keyword arguments for grpc.insecure_channel.
      server_kwargs: Extra keyword arguments for grpc.server.
      server_handler: Generic RPC handler installed on the server.

    Yields:
      A ``(client_channel, proxy, server)`` tuple.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    server.add_generic_rpc_handlers((server_handler,))
    server_port = server.add_insecure_port('{}:0'.format(_HOST))
    server.start()
    # Guard everything after start(): if the proxy or the channel fails to
    # come up, the already-running server must still be stopped.
    try:
        with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
            proxy_port = proxy.get_port()
            with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port),
                                       **channel_kwargs) as client_channel:
                yield client_channel, proxy, server
    finally:
        server.stop(None)
189
190
def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function,
                     server_kwargs, server_handler, message):
    """Perform one client exchange through the proxy and report its traffic.

    Returns:
      The value of ``proxy.get_byte_count()`` taken after the exchange;
      callers unpack it as ``(bytes_sent, bytes_received)``.
    """
    with _instrumented_client_server_pair(
            channel_kwargs, server_kwargs,
            server_handler) as (channel, proxy, _unused_server):
        client_function(channel, multicallable_kwargs, message)
        traffic = proxy.get_byte_count()
    return traffic
198
199
def _get_compression_ratios(client_function, first_channel_kwargs,
                            first_multicallable_kwargs, first_server_kwargs,
                            first_server_handler, second_channel_kwargs,
                            second_multicallable_kwargs, second_server_kwargs,
                            second_server_handler, message):
    """Compare the wire traffic of two configurations of the same exchange.

    Returns:
      ``(sent_ratio, received_ratio)``, each computed as
      ``(second - first) / first`` over the proxy's byte counts. Negative
      values mean the second configuration moved fewer bytes.
    """
    # This test requires the byte length of each connection to be
    # deterministic. As it turns out, flow control puts bytes on the wire in
    # a nondeterministic manner. We disable it here in order to measure
    # compression ratios deterministically.
    env_var = 'GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'
    previous_value = os.environ.get(env_var)
    os.environ[env_var] = 'true'
    try:
        first_bytes_sent, first_bytes_received = _get_byte_counts(
            first_channel_kwargs, first_multicallable_kwargs, client_function,
            first_server_kwargs, first_server_handler, message)
        second_bytes_sent, second_bytes_received = _get_byte_counts(
            second_channel_kwargs, second_multicallable_kwargs, client_function,
            second_server_kwargs, second_server_handler, message)
    finally:
        # Restore the caller's environment instead of unconditionally
        # deleting the variable, which would clobber a pre-existing setting.
        if previous_value is None:
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = previous_value
    return ((second_bytes_sent - first_bytes_sent) / float(first_bytes_sent),
            (second_bytes_received - first_bytes_received) /
            float(first_bytes_received))
223
224
def _unary_unary_client(channel, multicallable_kwargs, message):
    """Send ``message`` once over the unary-unary method and check the echo.

    Raises:
      RuntimeError: If the response differs from ``message``.
    """
    echo = channel.unary_unary(_UNARY_UNARY)
    response = echo(message, **multicallable_kwargs)
    if response == message:
        return
    raise RuntimeError("Request '{}' != Response '{}'".format(
        message, response))
231
232
def _unary_stream_client(channel, multicallable_kwargs, message):
    """Send ``message`` to the unary-stream method and check every echo.

    Raises:
      RuntimeError: On the first streamed response that differs from
        ``message``.
    """
    repeat = channel.unary_stream(_UNARY_STREAM)
    for received in repeat(message, **multicallable_kwargs):
        if received == message:
            continue
        raise RuntimeError("Request '{}' != Response '{}'".format(
            message, received))
240
241
def _stream_unary_client(channel, multicallable_kwargs, message):
    """Stream ``_STREAM_LENGTH`` copies of ``message`` and check the reply.

    The stream-unary handler in this file echoes the first request it
    receives, so the single response must equal ``message``.

    Raises:
      RuntimeError: If the response differs from ``message``.
    """
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    # Send the caller's ``message`` (not the module-level _REQUEST) so the
    # comparison below holds for any message, like the other clients.
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError("Request '{}' != Response '{}'".format(
            message, response))
249
250
def _stream_stream_client(channel, multicallable_kwargs, message):
    """Stream numbered requests and check that they are echoed in order.

    ``message`` is not used. Request ``i`` is 100 ASCII ``'0'`` characters
    followed by the ASCII decimal ``i``, so each response can be parsed back
    into its index.

    Raises:
      RuntimeError: On the first response whose parsed index is out of order.
    """
    padding = str(0).encode('ascii') * 100

    def _numbered_requests():
        for index in range(_STREAM_LENGTH):
            yield padding + str(index).encode('ascii')

    echo = channel.stream_stream(_STREAM_STREAM)
    responses = echo(_numbered_requests(), **multicallable_kwargs)
    for expected_index, response in enumerate(responses):
        if int(response.decode('ascii')) == expected_index:
            continue
        raise RuntimeError("Request '{}' != Response '{}'".format(
            expected_index, response))
261
262
class CompressionTest(unittest.TestCase):
    """Detects compression by measuring traffic through a TCP proxy.

    Each check compares a baseline run with no compression configured
    against a run with the configuration under test. A configuration counts
    as compressed when traffic shrinks by more than
    _COMPRESSION_RATIO_THRESHOLD. Per-configuration test methods are attached
    to this class by a module-level loop after its definition.
    """

    def assertCompressed(self, compression_ratio):
        """Assert that the ratio shows a meaningful size reduction.

        Args:
          compression_ratio: Relative byte change ``(new - baseline) /
            baseline``; negative means fewer bytes were transferred.
        """
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    def assertNotCompressed(self, compression_ratio):
        """Assert that the ratio shows no meaningful size reduction.

        Args:
          compression_ratio: Relative byte change ``(new - baseline) /
            baseline``; negative means fewer bytes were transferred.
        """
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    @staticmethod
    def _is_compressing(compression_method):
        """Return True if ``compression_method`` actually compresses data."""
        return (compression_method is not None and
                compression_method != grpc.Compression.NoCompression)

    def assertConfigurationCompressed(self, client_streaming, server_streaming,
                                      channel_compression,
                                      multicallable_compression,
                                      server_compression,
                                      server_call_compression):
        """Check one combination of compression settings end to end.

        Args:
          client_streaming: Whether the client streams its requests.
          server_streaming: Whether the server streams its responses.
          channel_compression: Channel-level default, or None.
          multicallable_compression: Per-call setting that takes precedence
            over the channel default, or None.
          server_compression: Server-level default, or None.
          server_call_compression: Per-call setting applied from the handler
            via ``ServicerContext.set_compression``, or None.
        """
        # The per-call setting takes precedence over the channel/server
        # default. Compare against None (not truthiness) so that an explicit
        # NoCompression override, which may compare falsy, still takes effect.
        client_method = (multicallable_compression
                         if multicallable_compression is not None else
                         channel_compression)
        server_method = (server_call_compression
                         if server_call_compression is not None else
                         server_compression)

        channel_kwargs = ({
            'compression': channel_compression
        } if channel_compression is not None else {})
        multicallable_kwargs = ({
            'compression': multicallable_compression
        } if multicallable_compression is not None else {})
        server_kwargs = ({
            'compression': server_compression
        } if server_compression is not None else {})

        if server_call_compression is not None:
            # Apply the requested algorithm rather than a hard-coded Gzip so
            # that other methods are exercised if re-enabled in the options.
            server_handler = _GenericHandler(
                functools.partial(set_call_compression,
                                  server_call_compression))
        else:
            server_handler = _GenericHandler(None)

        client_function = {
            (False, False): _unary_unary_client,
            (False, True): _unary_stream_client,
            (True, False): _stream_unary_client,
            (True, True): _stream_stream_client,
        }[(bool(client_streaming), bool(server_streaming))]

        # The first (baseline) configuration has no compression configured.
        sent_ratio, received_ratio = _get_compression_ratios(
            client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs,
            multicallable_kwargs, server_kwargs, server_handler, _REQUEST)

        if self._is_compressing(client_method):
            self.assertCompressed(sent_ratio)
        else:
            self.assertNotCompressed(sent_ratio)

        if self._is_compressing(server_method):
            self.assertCompressed(received_ratio)
        else:
            self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreaming(self):
        """Disabling compression before every response defeats the default.

        The server defaults to Deflate, but the handler disables compression
        ahead of each echoed message, so responses should not shrink.
        """
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_next_compression), _REQUEST)
        self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreamingResets(self):
        """Disabling compression applies only to the next message.

        Only response 0 is sent uncompressed; later responses fall back to
        the server's Deflate default, so overall traffic should still shrink.
        """
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_first_compression), _REQUEST)
        self.assertCompressed(received_ratio)
338
339
def _get_compression_str(name, value):
    """Return ``name`` joined with the display name of compression ``value``.

    Raises:
      KeyError: If ``value`` has no entry in _COMPRESSION_NAMES.
    """
    label = _COMPRESSION_NAMES[value]
    return '{prefix}{label}'.format(prefix=name, label=label)
342
343
def _get_compression_test_name(client_streaming, server_streaming,
                               channel_compression, multicallable_compression,
                               server_compression, server_call_compression):
    """Build a unique ``test...`` method name for one configuration.

    The name is 'test', then the client/server arity ('Unary' or 'Stream'
    for each side), then one fragment per compression knob in the order
    Channel, Multicallable, Server, ServerCall.
    """
    arity = ''.join('Stream' if streaming else 'Unary'
                    for streaming in (client_streaming, server_streaming))
    knobs = (
        ('Channel', channel_compression),
        ('Multicallable', multicallable_compression),
        ('Server', server_compression),
        ('ServerCall', server_call_compression),
    )
    fragments = [_get_compression_str(label, method) for label, method in knobs]
    return 'test' + arity + ''.join(fragments)
361
362
def _test_options():
    """Yield every combination of _TEST_OPTIONS as a keyword-argument dict."""
    option_names = list(_TEST_OPTIONS.keys())
    option_values = [_TEST_OPTIONS[name] for name in option_names]
    for combination in itertools.product(*option_values):
        yield dict(zip(option_names, combination))
366
367
def _make_compression_test(**kwargs):
    """Return a test method that checks the configuration in ``kwargs``.

    Binding ``kwargs`` in this factory's scope gives each generated method
    its own configuration. Keeping the factory private (and defined once,
    outside the loop) avoids leaving a public ``test_*`` function at module
    level that test collectors such as pytest would pick up as a test.
    """

    def _test_compression(self):
        self.assertConfigurationCompressed(**kwargs)

    return _test_compression


for _options in _test_options():
    setattr(CompressionTest, _get_compression_test_name(**_options),
            _make_compression_test(**_options))
379
if __name__ == '__main__':
    # Install a default stderr handler on the root logger, then run the
    # suite with per-test output.
    logging.basicConfig()
    unittest.main(verbosity=2)