1 # Copyright 2016 gRPC authors.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Tests server and client side compression."""
19 from concurrent import futures
26 from grpc import _grpcio_metadata
28 from tests.unit import test_common
29 from tests.unit.framework.common import test_constants
30 from tests.unit import _tcp_proxy
# Method paths routed by _GenericHandler, one per RPC arity.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

# Payload of 100 zero bytes used as the request message by the tests.
_REQUEST = b'\x00' * 100
# A direction counts as compressed when its byte count shrinks by more than
# this fraction relative to the baseline run.
_COMPRESSION_RATIO_THRESHOLD = 0.05
# NOTE(review): the leading None ("nothing configured") entry is inferred from
# the `... if channel_compression else {}` branches in CompressionTest -
# confirm against the full file.
_COMPRESSION_METHODS = (
    None,
    # Disabled for test tractability.
    # grpc.Compression.NoCompression,
    # grpc.Compression.Deflate,
    grpc.Compression.Gzip,
)
# Suffixes used when generating test method names.
# NOTE(review): the label for None is an assumption - confirm.
_COMPRESSION_NAMES = {
    None: 'Uncompressed',
    grpc.Compression.NoCompression: 'NoCompression',
    grpc.Compression.Deflate: 'DeflateCompression',
    grpc.Compression.Gzip: 'GzipCompression',
}

# Keyword arguments of CompressionTest.assertConfigurationCompressed mapped to
# the values each may take; the cartesian product defines the generated tests.
_TEST_OPTIONS = {
    'client_streaming': (True, False),
    'server_streaming': (True, False),
    'channel_compression': _COMPRESSION_METHODS,
    'multicallable_compression': _COMPRESSION_METHODS,
    'server_compression': _COMPRESSION_METHODS,
    'server_call_compression': _COMPRESSION_METHODS,
}
68 def _make_handle_unary_unary(pre_response_callback):
70 def _handle_unary(request, servicer_context):
71 if pre_response_callback:
72 pre_response_callback(request, servicer_context)
def _make_handle_unary_stream(pre_response_callback):
    """Build a unary-stream handler that repeats its request.

    Args:
      pre_response_callback: Optional callable invoked once as
        ``pre_response_callback(request, servicer_context)`` before any
        response is produced. Skipped when falsy.

    Returns:
      A generator function ``(request, servicer_context)`` yielding the
      request ``_STREAM_LENGTH`` times.
    """

    def _handle_unary_stream(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        for _ in range(_STREAM_LENGTH):
            # Echo: _unary_stream_client checks every response == request.
            yield request

    return _handle_unary_stream
89 def _make_handle_stream_unary(pre_response_callback):
91 def _handle_stream_unary(request_iterator, servicer_context):
92 if pre_response_callback:
93 pre_response_callback(request_iterator, servicer_context)
95 for request in request_iterator:
100 return _handle_stream_unary
103 def _make_handle_stream_stream(pre_response_callback):
105 def _handle_stream(request_iterator, servicer_context):
106 # TODO(issue:#6891) We should be able to remove this loop,
107 # and replace with return; yield
108 for request in request_iterator:
109 if pre_response_callback:
110 pre_response_callback(request, servicer_context)
113 return _handle_stream
def set_call_compression(compression_method, request_or_iterator,
                         servicer_context):
    """Pre-response callback that sets the compression method for the call.

    Meant to be bound with ``functools.partial`` so that only the
    ``(request_or_iterator, servicer_context)`` pair remains.

    Args:
      compression_method: Value forwarded to
        ``servicer_context.set_compression``.
      request_or_iterator: Ignored; present only to match the callback
        signature.
      servicer_context: Context of the RPC being served.
    """
    del request_or_iterator
    servicer_context.set_compression(compression_method)
def disable_next_compression(request, servicer_context):
    """Pre-response callback that disables compression of the next message.

    Args:
      request: Ignored; present only to match the callback signature.
      servicer_context: Context of the RPC being served.
    """
    del request  # Unused.
    servicer_context.disable_next_message_compression()
def disable_first_compression(request, servicer_context):
    """Pre-response callback that disables compression for request zero only.

    The request is decoded as ASCII and parsed as an integer; compression of
    the next message is disabled only when that integer is 0.

    Args:
      request: Bytes holding an ASCII-encoded integer.
      servicer_context: Context of the RPC being served.
    """
    message_index = int(request.decode('ascii'))
    if message_index != 0:
        return
    servicer_context.disable_next_message_compression()
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler exposing exactly one echo behavior.

    The behavior attribute matching the requested arity is populated; the
    other three stay None. No (de)serializers are installed.
    """

    def __init__(self, request_streaming, response_streaming,
                 pre_response_callback):
        """Select and build the behavior for the given arity.

        Args:
          request_streaming: Truthy when the client streams requests.
          response_streaming: Truthy when the server streams responses.
          pre_response_callback: Optional callback handed to the handler
            factory; may be None.
        """
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        # (request_streaming, response_streaming) -> (attribute, factory).
        behaviors = {
            (True, True): ('stream_stream', _make_handle_stream_stream),
            (False, False): ('unary_unary', _make_handle_unary_unary),
            (False, True): ('unary_stream', _make_handle_unary_stream),
            (True, False): ('stream_unary', _make_handle_stream_unary),
        }
        arity = (bool(self.request_streaming), bool(self.response_streaming))
        attribute, factory = behaviors[arity]
        setattr(self, attribute, factory(pre_response_callback))
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler routing the four test methods to _MethodHandler."""

    def __init__(self, pre_response_callback):
        """Store the callback passed to every _MethodHandler created.

        Args:
          pre_response_callback: Optional callback run by the handlers before
            responding; may be None.
        """
        self._pre_response_callback = pre_response_callback

    def service(self, handler_call_details):
        """Return the handler for the called method.

        Args:
          handler_call_details: Object whose ``method`` attribute names the
            RPC being invoked.

        Returns:
          A _MethodHandler for one of the four known method paths, otherwise
          None.
        """
        # Method path -> (request_streaming, response_streaming).
        streaming_by_method = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }
        streaming = streaming_by_method.get(handler_call_details.method)
        if streaming is None:
            return None
        request_streaming, response_streaming = streaming
        return _MethodHandler(request_streaming, response_streaming,
                              self._pre_response_callback)
@contextlib.contextmanager
def _instrumented_client_server_pair(channel_kwargs, server_kwargs,
                                     server_handler):
    """Run a server behind a TCP proxy and yield a channel connected to it.

    Args:
      channel_kwargs: Extra keyword arguments for ``grpc.insecure_channel``.
      server_kwargs: Extra keyword arguments for ``grpc.server``.
      server_handler: Generic RPC handler registered on the server.

    Yields:
      A ``(client_channel, proxy, server)`` tuple. The channel dials the
      proxy's port rather than the server's, so the proxy observes the
      traffic between them.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    server.add_generic_rpc_handlers((server_handler,))
    # Port 0 asks for any free port; the chosen one is returned.
    server_port = server.add_insecure_port('{}:0'.format(_HOST))
    server.start()
    with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
        proxy_port = proxy.get_port()
        with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port),
                                   **channel_kwargs) as client_channel:
            try:
                yield client_channel, proxy, server
            finally:
                # Stop without a grace period, even if the caller's body
                # raised, so the server does not outlive the test.
                server.stop(None)
def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function,
                     server_kwargs, server_handler, message):
    """Run one client scenario through the proxy and report its byte count.

    Args:
      channel_kwargs: Keyword arguments for creating the client channel.
      multicallable_kwargs: Keyword arguments passed on to ``client_function``
        for the RPC invocation.
      client_function: Callable invoked as
        ``client_function(channel, multicallable_kwargs, message)``.
      server_kwargs: Keyword arguments for creating the server.
      server_handler: Generic RPC handler registered on the server.
      message: Message handed to ``client_function``.

    Returns:
      The value of ``proxy.get_byte_count()``, read after the client function
      returns and before the pipeline is torn down. Callers unpack it as a
      ``(bytes_sent, bytes_received)`` pair.
    """
    pair = _instrumented_client_server_pair(channel_kwargs, server_kwargs,
                                            server_handler)
    with pair as (client_channel, proxy, _):
        client_function(client_channel, multicallable_kwargs, message)
        byte_count = proxy.get_byte_count()
    return byte_count
def _get_compression_ratios(client_function, first_channel_kwargs,
                            first_multicallable_kwargs, first_server_kwargs,
                            first_server_handler, second_channel_kwargs,
                            second_multicallable_kwargs, second_server_kwargs,
                            second_server_handler, message):
    """Compare bytes on the wire between two client/server configurations.

    The same ``client_function`` is run once per configuration.

    Args:
      client_function: Callable invoked as
        ``client_function(channel, multicallable_kwargs, message)``.
      first_channel_kwargs: Channel keyword arguments for the first run.
      first_multicallable_kwargs: Invocation keyword arguments for the first
        run.
      first_server_kwargs: Server keyword arguments for the first run.
      first_server_handler: Generic RPC handler for the first run.
      second_channel_kwargs: Channel keyword arguments for the second run.
      second_multicallable_kwargs: Invocation keyword arguments for the
        second run.
      second_server_kwargs: Server keyword arguments for the second run.
      second_server_handler: Generic RPC handler for the second run.
      message: Message handed to ``client_function`` in both runs.

    Returns:
      A ``(sent_ratio, received_ratio)`` tuple of floats, each computed as
      ``(second - first) / first``. A negative value means the second run put
      fewer bytes on the wire in that direction.
    """
    flow_control_flag = 'GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'
    # Remember any value set by the environment so it can be put back.
    previous_value = os.environ.get(flow_control_flag)
    # This test requires the byte length of each connection to be
    # deterministic. As it turns out, flow control puts bytes on the wire in
    # a nondeterministic manner. We disable it here in order to measure
    # compression ratios deterministically.
    os.environ[flow_control_flag] = 'true'
    try:
        first_bytes_sent, first_bytes_received = _get_byte_counts(
            first_channel_kwargs, first_multicallable_kwargs, client_function,
            first_server_kwargs, first_server_handler, message)
        second_bytes_sent, second_bytes_received = _get_byte_counts(
            second_channel_kwargs, second_multicallable_kwargs,
            client_function, second_server_kwargs, second_server_handler,
            message)
        return ((second_bytes_sent - first_bytes_sent) /
                float(first_bytes_sent),
                (second_bytes_received - first_bytes_received) /
                float(first_bytes_received))
    finally:
        # Always undo the override, including when a run raises.
        if previous_value is None:
            os.environ.pop(flow_control_flag, None)
        else:
            os.environ[flow_control_flag] = previous_value
def _unary_unary_client(channel, multicallable_kwargs, message):
    """Issue one unary-unary RPC and require the response to echo it.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the RPC invocation.
      message: Request to send and response expected back.

    Raises:
      RuntimeError: If the response differs from ``message``.
    """
    invoke = channel.unary_unary(_UNARY_UNARY)
    echoed = invoke(message, **multicallable_kwargs)
    if echoed == message:
        return
    raise RuntimeError("Request '{}' != Response '{}'".format(
        message, echoed))
def _unary_stream_client(channel, multicallable_kwargs, message):
    """Issue one unary-stream RPC and require every response to echo it.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the RPC invocation.
      message: Request to send and value expected for each response.

    Raises:
      RuntimeError: At the first response that differs from ``message``.
    """
    invoke = channel.unary_stream(_UNARY_STREAM)
    for echoed in invoke(message, **multicallable_kwargs):
        if echoed == message:
            continue
        raise RuntimeError("Request '{}' != Response '{}'".format(
            message, echoed))
def _stream_unary_client(channel, multicallable_kwargs, message):
    """Stream ``message`` repeatedly and require the single response to match.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the RPC invocation.
      message: Request sent ``_STREAM_LENGTH`` times and response expected
        back.

    Raises:
      RuntimeError: If the response differs from ``message``.
    """
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    # Send the message under test rather than the module-level _REQUEST, so
    # the requests agree with the value the response is compared against.
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError("Request '{}' != Response '{}'".format(
            message, response))
def _stream_stream_client(channel, multicallable_kwargs, message):
    """Stream numbered requests and require them echoed back in order.

    Request ``i`` is 100 ASCII zeros followed by the decimal digits of ``i``,
    so parsing it as an integer yields ``i``.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the RPC invocation.
      message: Unused; accepted so all client functions share one signature.

    Raises:
      RuntimeError: If the i-th response does not parse to ``i``.
    """
    multi_callable = channel.stream_stream(_STREAM_STREAM)
    request_prefix = str(0).encode('ascii') * 100
    requests = (
        request_prefix + str(i).encode('ascii') for i in range(_STREAM_LENGTH))
    response_iterator = multi_callable(requests, **multicallable_kwargs)
    for i, response in enumerate(response_iterator):
        if int(response.decode('ascii')) != i:
            raise RuntimeError("Request '{}' != Response '{}'".format(
                i, response))
class CompressionTest(unittest.TestCase):
    """Checks compression settings by measuring bytes on the wire.

    Every ratio is the relative change in byte count against a baseline run;
    negative means fewer bytes than the baseline.
    """

    def assertCompressed(self, compression_ratio):
        """Assert the byte count shrank by more than the threshold."""
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    def assertNotCompressed(self, compression_ratio):
        """Assert the byte count did not shrink by more than the threshold."""
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    def assertConfigurationCompressed(self, client_streaming, server_streaming,
                                      channel_compression,
                                      multicallable_compression,
                                      server_compression,
                                      server_call_compression):
        """Check one configuration against a baseline with nothing set.

        Args:
          client_streaming: Whether the client streams requests.
          server_streaming: Whether the server streams responses.
          channel_compression: Compression set on the channel, or a falsy
            value to leave it unset.
          multicallable_compression: Compression set on the RPC invocation,
            or a falsy value to leave it unset.
          server_compression: Compression set on the server, or a falsy value
            to leave it unset.
          server_call_compression: Compression set by the handler for the
            call, or a falsy value to leave it unset.
        """
        client_side_compressed = channel_compression or multicallable_compression
        server_side_compressed = server_compression or server_call_compression
        channel_kwargs = {
            'compression': channel_compression,
        } if channel_compression else {}
        multicallable_kwargs = {
            'compression': multicallable_compression,
        } if multicallable_compression else {}

        client_function = None
        if not client_streaming and not server_streaming:
            client_function = _unary_unary_client
        elif not client_streaming and server_streaming:
            client_function = _unary_stream_client
        elif client_streaming and not server_streaming:
            client_function = _stream_unary_client
        else:
            client_function = _stream_stream_client

        server_kwargs = {
            'compression': server_compression,
        } if server_compression else {}
        # Apply the requested per-call method instead of a hard-coded Gzip,
        # so the test stays right if more methods are enabled.
        server_handler = _GenericHandler(
            functools.partial(set_call_compression, server_call_compression)
        ) if server_call_compression else _GenericHandler(None)
        sent_ratio, received_ratio = _get_compression_ratios(
            client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs,
            multicallable_kwargs, server_kwargs, server_handler, _REQUEST)

        if client_side_compressed:
            self.assertCompressed(sent_ratio)
        else:
            self.assertNotCompressed(sent_ratio)

        if server_side_compressed:
            self.assertCompressed(received_ratio)
        else:
            self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreaming(self):
        """Disabling compression before every response leaves none compressed."""
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_next_compression), _REQUEST)
        self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreamingResets(self):
        """Disabling compression for the first response only is not sticky."""
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_first_compression), _REQUEST)
        self.assertCompressed(received_ratio)
def _get_compression_str(name, value):
    """Return ``name`` followed by the display name of a compression value.

    Args:
      name: Prefix identifying which setting the value belongs to.
      value: Key into ``_COMPRESSION_NAMES``.

    Raises:
      KeyError: If ``value`` has no entry in ``_COMPRESSION_NAMES``.
    """
    compression_label = _COMPRESSION_NAMES[value]
    return '{}{}'.format(name, compression_label)
def _get_compression_test_name(client_streaming, server_streaming,
                               channel_compression, multicallable_compression,
                               server_compression, server_call_compression):
    """Build the test method name for one option combination.

    The name is ``test`` followed by the client and server arity (``Stream``
    or ``Unary``), then one labelled fragment per compression setting in the
    order channel, multicallable, server, server call.

    Args:
      client_streaming: Whether the client streams requests.
      server_streaming: Whether the server streams responses.
      channel_compression: Channel-level compression value.
      multicallable_compression: Invocation-level compression value.
      server_compression: Server-level compression value.
      server_call_compression: Per-call server compression value.

    Returns:
      The generated method name as a string.
    """
    arity = '{}{}'.format('Stream' if client_streaming else 'Unary',
                          'Stream' if server_streaming else 'Unary')
    labelled_settings = (
        ('Channel', channel_compression),
        ('Multicallable', multicallable_compression),
        ('Server', server_compression),
        ('ServerCall', server_call_compression),
    )
    fragments = [
        _get_compression_str(label, setting)
        for label, setting in labelled_settings
    ]
    return 'test{}{}{}{}{}'.format(arity, *fragments)
def _test_options():
    """Yield every combination of the values in ``_TEST_OPTIONS``.

    Yields:
      One dict per element of the cartesian product of the option values,
      mapping each option name to the value chosen for it.
    """
    option_names = tuple(_TEST_OPTIONS)
    option_values = (_TEST_OPTIONS[name] for name in option_names)
    for test_parameters in itertools.product(*option_values):
        yield dict(zip(option_names, test_parameters))
def test_compression(**kwargs):
    """Bind one option combination into a CompressionTest method.

    A factory is used so each generated method closes over its own kwargs
    rather than the loop variable.
    """

    def _test_compression(self):
        self.assertConfigurationCompressed(**kwargs)

    return _test_compression


# Attach one generated test method per option combination.
for options in _test_options():
    generated_name = _get_compression_test_name(**options)
    setattr(CompressionTest, generated_name, test_compression(**options))

if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)