Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / _metadata_test.py
1 # Copyright 2016 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Tests server and client side metadata API."""
15
16 import unittest
17 import weakref
18 import logging
19
20 import grpc
21 from grpc import _channel
22
23 from tests.unit import test_common
24 from tests.unit.framework.common import test_constants
25
# Channel options used by the client channel; the server-side validation
# checks the resulting user-agent for both of these values.
_CHANNEL_ARGS = (
    ('grpc.primary_user_agent', 'primary-agent'),
    ('grpc.secondary_user_agent', 'secondary-agent'),
)

# Opaque payloads exchanged by every RPC in this test.
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'

# Fully-qualified method names, one per RPC arity.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Metadata sent by the client. Keys and values deliberately mix bytes and
# text literals; the "-bin" entry carries a bytes value.
_INVOCATION_METADATA = (
    (b'invocation-md-key', u'invocation-md-value'),
    (u'invocation-md-key-bin', b'\x00\x01'),
)
# What the servicer is expected to observe for the metadata above.
_EXPECTED_INVOCATION_METADATA = (
    ('invocation-md-key', 'invocation-md-value'),
    ('invocation-md-key-bin', b'\x00\x01'),
)

# Initial metadata sent by the server, again mixing bytes and text literals.
_INITIAL_METADATA = (
    (b'initial-md-key', u'initial-md-value'),
    (u'initial-md-key-bin', b'\x00\x02'),
)
# What the client is expected to observe for the metadata above.
_EXPECTED_INITIAL_METADATA = (
    ('initial-md-key', 'initial-md-value'),
    ('initial-md-key-bin', b'\x00\x02'),
)

# Trailing metadata set by the server; the client expects it unchanged.
_TRAILING_METADATA = (
    ('server-trailing-md-key', 'server-trailing-md-value'),
    ('server-trailing-md-key-bin', b'\x00\x03'),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
82
83
def _user_agent(metadata):
    """Returns the value of the first 'user-agent' entry in metadata.

    Args:
      metadata: An iterable of (key, value) pairs.

    Returns:
      The value paired with the first key equal to 'user-agent'.

    Raises:
      KeyError: If no pair has the key 'user-agent'.
    """
    # A unique sentinel distinguishes "not found" from any real value.
    not_found = object()
    agent = next(
        (value for name, value in metadata if name == 'user-agent'),
        not_found)
    if agent is not_found:
        raise KeyError('No user agent!')
    return agent
89
90
def validate_client_metadata(test, servicer_context):
    """Asserts, server side, that the client's metadata arrived as expected.

    Checks that the expected invocation metadata was transmitted and that
    the user-agent starts with the primary agent followed by the library's
    own user-agent string and ends with the secondary agent.

    Args:
      test: The test case (or a proxy to it) used to make assertions.
      servicer_context: The servicer context of the RPC being handled.
    """
    received = servicer_context.invocation_metadata()
    transmitted = test_common.metadata_transmitted(
        _EXPECTED_INVOCATION_METADATA, received)
    test.assertTrue(transmitted)

    agent = _user_agent(received)
    expected_prefix = 'primary-agent ' + _channel._USER_AGENT
    test.assertTrue(agent.startswith(expected_prefix))
    test.assertTrue(agent.endswith('secondary-agent'))
100
101
def handle_unary_unary(test, request, servicer_context):
    """Handles a unary-unary RPC.

    Validates the client's metadata, sends the initial metadata, registers
    the trailing metadata and answers with the canned response.

    Args:
      test: The test case (or a proxy to it) used to make assertions.
      request: The request message; its content is not inspected.
      servicer_context: The servicer context of the RPC being handled.

    Returns:
      The canned response payload.
    """
    validate_client_metadata(test, servicer_context)
    context = servicer_context
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    return _RESPONSE
107
108
def handle_unary_stream(test, request, servicer_context):
    """Handles a unary-stream RPC.

    This is a generator: metadata validation and sending happen when the
    first response is requested.

    Args:
      test: The test case (or a proxy to it) used to make assertions.
      request: The request message; its content is not inspected.
      servicer_context: The servicer context of the RPC being handled.

    Yields:
      The canned response payload, test_constants.STREAM_LENGTH times.
    """
    validate_client_metadata(test, servicer_context)
    context = servicer_context
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    for _unused_index in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE
115
116
def handle_stream_unary(test, request_iterator, servicer_context):
    """Handles a stream-unary RPC.

    Validates the client's metadata, sends the initial metadata, registers
    the trailing metadata, consumes every request and answers with the
    canned response.

    Args:
      test: The test case (or a proxy to it) used to make assertions.
      request_iterator: An iterator over the request messages.
      servicer_context: The servicer context of the RPC being handled.

    Returns:
      The canned response payload.
    """
    validate_client_metadata(test, servicer_context)
    context = servicer_context
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop
    # Drain the request stream; the messages themselves are ignored.
    for _unused_request in request_iterator:
        pass
    return _RESPONSE
125
126
def handle_stream_stream(test, request_iterator, servicer_context):
    """Handles a stream-stream RPC.

    This is a generator: metadata validation and sending happen when the
    first response is requested.

    Args:
      test: The test case (or a proxy to it) used to make assertions.
      request_iterator: An iterator over the request messages.
      servicer_context: The servicer context of the RPC being handled.

    Yields:
      The canned response payload, once per request received.
    """
    validate_client_metadata(test, servicer_context)
    context = servicer_context
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _unused_request in request_iterator:
        yield _RESPONSE
135
136
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler that exposes exactly one behavior for a given arity.

    Of unary_unary, unary_stream, stream_unary and stream_stream, only the
    attribute matching the requested streaming flags is set; the others
    stay None. No request deserializer or response serializer is installed.
    """

    def __init__(self, test, request_streaming, response_streaming):
        """Initializes the handler.

        Args:
          test: The test case (or a proxy to it) passed on to the handler
            function so it can make assertions.
          request_streaming: Whether the request side is a stream.
          response_streaming: Whether the response side is a stream.
        """

        def bind(behavior):
            # Adapt a (test, request, context) function to the two-argument
            # call shape used for the behavior attributes.
            def bound(request_or_iterator, servicer_context):
                return behavior(test, request_or_iterator, servicer_context)

            return bound

        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if request_streaming:
            if response_streaming:
                self.stream_stream = bind(handle_stream_stream)
            else:
                self.stream_unary = bind(handle_stream_unary)
        else:
            if response_streaming:
                self.unary_stream = bind(handle_unary_stream)
            else:
                self.unary_unary = bind(handle_unary_unary)
156
157
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that routes the four test methods by name."""

    def __init__(self, test):
        """Initializes the handler.

        Args:
          test: The test case (or a proxy to it) handed to every method
            handler this object creates.
        """
        self._test = test

    def service(self, handler_call_details):
        """Returns a method handler for the called method, if it is known.

        Args:
          handler_call_details: Details of the incoming call; only its
            `method` attribute is consulted.

        Returns:
          A _MethodHandler configured for the method's arity, or None when
          the method name is not one of the four test methods.
        """
        # Maps method name -> (request_streaming, response_streaming).
        streaming_by_method = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }
        flags = streaming_by_method.get(handler_call_details.method)
        if flags is None:
            return None
        request_streaming, response_streaming = flags
        return _MethodHandler(self._test, request_streaming,
                              response_streaming)
174
175
class MetadataTest(unittest.TestCase):
    """Checks metadata round-trips for each of the four RPC arities.

    Client-to-server metadata is validated inside the server handlers; the
    test methods here validate the initial and trailing metadata that the
    client receives back.
    """

    def setUp(self):
        """Starts a test server and opens an insecure channel to it."""
        server = test_common.test_server()
        # The handlers get a weak proxy to this test case rather than a
        # strong reference.
        handler = _GenericHandler(weakref.proxy(self))
        server.add_generic_rpc_handlers((handler,))
        port = server.add_insecure_port('[::]:0')
        server.start()
        self._server = server
        self._channel = grpc.insecure_channel('localhost:%d' % port,
                                              options=_CHANNEL_ARGS)

    def tearDown(self):
        """Stops the server with no grace period, then closes the channel."""
        self._server.stop(0)
        self._channel.close()

    def _assert_initial_metadata(self, call):
        """Asserts that `call` received the expected initial metadata."""
        received = call.initial_metadata()
        self.assertTrue(
            test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
                                             received))

    def _assert_trailing_metadata(self, call):
        """Asserts that `call` received the expected trailing metadata."""
        received = call.trailing_metadata()
        self.assertTrue(
            test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA,
                                             received))

    def testUnaryUnary(self):
        invoke = self._channel.unary_unary(_UNARY_UNARY)
        unused_response, call = invoke.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        self._assert_trailing_metadata(call)

    def testUnaryStream(self):
        invoke = self._channel.unary_stream(_UNARY_STREAM)
        call = invoke(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        # Consume all responses before asking for the trailing metadata.
        for _unused_response in call:
            pass
        self._assert_trailing_metadata(call)

    def testStreamUnary(self):
        invoke = self._channel.stream_unary(_STREAM_UNARY)
        requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
        unused_response, call = invoke.with_call(
            requests, metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        self._assert_trailing_metadata(call)

    def testStreamStream(self):
        invoke = self._channel.stream_stream(_STREAM_STREAM)
        requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
        call = invoke(requests, metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        # Consume all responses before asking for the trailing metadata.
        for _unused_response in call:
            pass
        self._assert_trailing_metadata(call)
238
239
if __name__ == '__main__':
    # Set up default logging before handing control to the unittest
    # runner, which is asked for verbose (per-test) output.
    logging.basicConfig()
    unittest.main(verbosity=2)