# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the server-side and client-side metadata APIs."""
import logging
import unittest
import weakref

import grpc
from grpc import _channel

from tests.unit import test_common
from tests.unit.framework.common import test_constants
# Channel options that set custom primary/secondary user-agent strings; the
# server-side validation checks that both appear in the received 'user-agent'.
_CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'),
                 ('grpc.secondary_user_agent', 'secondary-agent'))

# Opaque payloads. No (de)serializers are configured on either side, so these
# bytes are passed through unchanged.
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'

# Fully-qualified method names served by _GenericHandler, one per RPC shape.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
# Metadata the client attaches to every call. Keys and values deliberately mix
# bytes and text so both forms are exercised; the key ending in '-bin' carries
# a binary value.
_INVOCATION_METADATA = (
    (
        b'invocation-md-key',
        u'invocation-md-value',
    ),
    (
        u'invocation-md-key-bin',
        b'\x00\x01',
    ),
)
# How the server should report _INVOCATION_METADATA: every key as str, the
# '-bin' value unchanged as bytes.
_EXPECTED_INVOCATION_METADATA = (
    (
        'invocation-md-key',
        'invocation-md-value',
    ),
    (
        'invocation-md-key-bin',
        b'\x00\x01',
    ),
)

# Initial metadata sent by every server handler, again mixing bytes and text.
_INITIAL_METADATA = ((b'initial-md-key', u'initial-md-value'),
                     (u'initial-md-key-bin', b'\x00\x02'))
# How the client should report _INITIAL_METADATA.
_EXPECTED_INITIAL_METADATA = (
    (
        'initial-md-key',
        'initial-md-value',
    ),
    (
        'initial-md-key-bin',
        b'\x00\x02',
    ),
)

# Trailing metadata set by every server handler. It already uses str keys, so
# the client is expected to report it verbatim.
_TRAILING_METADATA = (
    (
        'server-trailing-md-key',
        'server-trailing-md-value',
    ),
    (
        'server-trailing-md-key-bin',
        b'\x00\x03',
    ),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
def _user_agent(metadata):
    """Return the value of the first 'user-agent' entry in ``metadata``.

    Args:
      metadata: An iterable of (key, value) pairs.

    Returns:
      The value paired with the first 'user-agent' key.

    Raises:
      KeyError: If no 'user-agent' entry is present.
    """
    for key, val in metadata:
        if key == 'user-agent':
            return val
    raise KeyError('No user agent!')
def validate_client_metadata(test, servicer_context):
    """Assert, via ``test``, that the server received the client's metadata.

    Checks that the invocation metadata matches _EXPECTED_INVOCATION_METADATA
    according to test_common.metadata_transmitted. It also checks that the
    user-agent starts with the configured primary agent followed by gRPC's own
    agent string, and ends with the configured secondary agent.

    Args:
      test: The TestCase used to report assertion failures.
      servicer_context: The server-side context of the current RPC.
    """
    invocation_metadata = servicer_context.invocation_metadata()
    test.assertTrue(
        test_common.metadata_transmitted(_EXPECTED_INVOCATION_METADATA,
                                         invocation_metadata))
    user_agent = _user_agent(invocation_metadata)
    test.assertTrue(
        user_agent.startswith('primary-agent ' + _channel._USER_AGENT))
    test.assertTrue(user_agent.endswith('secondary-agent'))
def handle_unary_unary(test, request, servicer_context):
    """Unary-unary behavior: check client metadata, send metadata, reply once.

    Validates the received invocation metadata, sends _INITIAL_METADATA, sets
    _TRAILING_METADATA and returns _RESPONSE. ``request`` is ignored.
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    return _RESPONSE
def handle_unary_stream(test, request, servicer_context):
    """Unary-stream behavior: check client metadata, send metadata, stream replies.

    Validates the received invocation metadata, sends _INITIAL_METADATA, sets
    _TRAILING_METADATA and yields _RESPONSE test_constants.STREAM_LENGTH
    times. ``request`` is ignored.
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE
def handle_stream_unary(test, request_iterator, servicer_context):
    """Stream-unary behavior: check client metadata, send metadata, reply once.

    Validates the received invocation metadata, sends _INITIAL_METADATA, sets
    _TRAILING_METADATA, drains ``request_iterator`` and returns _RESPONSE.
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop
    # Consume every request before replying; the payloads themselves are
    # irrelevant to this test.
    for _ in request_iterator:
        pass
    return _RESPONSE
def handle_stream_stream(test, request_iterator, servicer_context):
    """Stream-stream behavior: check client metadata, send metadata, echo count.

    Validates the received invocation metadata, sends _INITIAL_METADATA, sets
    _TRAILING_METADATA, then yields one _RESPONSE per request received.
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE
class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler that routes one RPC shape to its handle_* function.

    Exactly one of unary_unary / unary_stream / stream_unary / stream_stream
    is set, chosen by the two streaming flags; the other three stay None.
    No request deserializer or response serializer is configured.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        # Each behavior closes over ``test`` so the handlers can assert on
        # what the server receives.
        if self.request_streaming and self.response_streaming:
            self.stream_stream = lambda x, y: handle_stream_stream(test, x, y)
        elif self.request_streaming:
            self.stream_unary = lambda x, y: handle_stream_unary(test, x, y)
        elif self.response_streaming:
            self.unary_stream = lambda x, y: handle_unary_stream(test, x, y)
        else:
            self.unary_unary = lambda x, y: handle_unary_unary(test, x, y)
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler serving the four test methods.

    Each known method name maps to its (request_streaming, response_streaming)
    flags. Unknown methods get None, meaning this handler does not serve them.
    """

    # Fully-qualified method name -> (request_streaming, response_streaming).
    _STREAMING_BY_METHOD = {
        _UNARY_UNARY: (False, False),
        _UNARY_STREAM: (False, True),
        _STREAM_UNARY: (True, False),
        _STREAM_STREAM: (True, True),
    }

    def __init__(self, test):
        # The TestCase (or a proxy to it) passed through to every handler.
        self._test = test

    def service(self, handler_call_details):
        """Return a _MethodHandler for a known method name, else None."""
        streaming = self._STREAMING_BY_METHOD.get(handler_call_details.method)
        if streaming is None:
            return None
        request_streaming, response_streaming = streaming
        return _MethodHandler(self._test, request_streaming,
                              response_streaming)
class MetadataTest(unittest.TestCase):
    """Checks that metadata flows both ways for all four RPC shapes.

    Each test makes a call with _INVOCATION_METADATA; the server handlers
    validate it and send back initial and trailing metadata, which the client
    side then checks.
    """

    def setUp(self):
        self._server = test_common.test_server()
        # Pass a weak proxy so the server's handler does not hold a strong
        # reference back to this TestCase.
        self._server.add_generic_rpc_handlers(
            (_GenericHandler(weakref.proxy(self)),))
        # Port 0: let the system pick a free port, then connect to it.
        port = self._server.add_insecure_port('[::]:0')
        self._server.start()
        self._channel = grpc.insecure_channel('localhost:%d' % port,
                                              options=_CHANNEL_ARGS)

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    def _assert_metadata_transmitted(self, expected, actual):
        """Assert test_common.metadata_transmitted(expected, actual) holds."""
        self.assertTrue(test_common.metadata_transmitted(expected, actual))

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(_UNARY_UNARY)
        unused_response, call = multi_callable.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_metadata_transmitted(_EXPECTED_INITIAL_METADATA,
                                          call.initial_metadata())
        self._assert_metadata_transmitted(_EXPECTED_TRAILING_METADATA,
                                          call.trailing_metadata())

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(_UNARY_STREAM)
        call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_metadata_transmitted(_EXPECTED_INITIAL_METADATA,
                                          call.initial_metadata())
        # Drain the response stream so the RPC completes before the trailing
        # metadata is read.
        for _ in call:
            pass
        self._assert_metadata_transmitted(_EXPECTED_TRAILING_METADATA,
                                          call.trailing_metadata())

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(_STREAM_UNARY)
        unused_response, call = multi_callable.with_call(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA)
        self._assert_metadata_transmitted(_EXPECTED_INITIAL_METADATA,
                                          call.initial_metadata())
        self._assert_metadata_transmitted(_EXPECTED_TRAILING_METADATA,
                                          call.trailing_metadata())

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(_STREAM_STREAM)
        call = multi_callable(iter([_REQUEST] * test_constants.STREAM_LENGTH),
                              metadata=_INVOCATION_METADATA)
        self._assert_metadata_transmitted(_EXPECTED_INITIAL_METADATA,
                                          call.initial_metadata())
        # Drain the response stream so the RPC completes before the trailing
        # metadata is read.
        for _ in call:
            pass
        self._assert_metadata_transmitted(_EXPECTED_TRAILING_METADATA,
                                          call.trailing_metadata())
if __name__ == '__main__':
    # Allow running this file directly; basicConfig makes any gRPC log output
    # visible on stderr.
    logging.basicConfig()
    unittest.main(verbosity=2)