Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / _cython / _read_some_but_not_all_responses_test.py
1 # Copyright 2016 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Test a corner-case at the level of the Cython API."""
15
16 import threading
17 import unittest
18
19 from grpc._cython import cygrpc
20 from tests.unit._cython import test_utilities
21
22 _EMPTY_FLAGS = 0
23 _EMPTY_METADATA = ()
24
25
26 class _ServerDriver(object):
27
28     def __init__(self, completion_queue, shutdown_tag):
29         self._condition = threading.Condition()
30         self._completion_queue = completion_queue
31         self._shutdown_tag = shutdown_tag
32         self._events = []
33         self._saw_shutdown_tag = False
34
35     def start(self):
36
37         def in_thread():
38             while True:
39                 event = self._completion_queue.poll()
40                 with self._condition:
41                     self._events.append(event)
42                     self._condition.notify()
43                     if event.tag is self._shutdown_tag:
44                         self._saw_shutdown_tag = True
45                         break
46
47         thread = threading.Thread(target=in_thread)
48         thread.start()
49
50     def done(self):
51         with self._condition:
52             return self._saw_shutdown_tag
53
54     def first_event(self):
55         with self._condition:
56             while not self._events:
57                 self._condition.wait()
58             return self._events[0]
59
60     def events(self):
61         with self._condition:
62             while not self._saw_shutdown_tag:
63                 self._condition.wait()
64             return tuple(self._events)
65
66
67 class _QueueDriver(object):
68
69     def __init__(self, condition, completion_queue, due):
70         self._condition = condition
71         self._completion_queue = completion_queue
72         self._due = due
73         self._events = []
74         self._returned = False
75
76     def start(self):
77
78         def in_thread():
79             while True:
80                 event = self._completion_queue.poll()
81                 with self._condition:
82                     self._events.append(event)
83                     self._due.remove(event.tag)
84                     self._condition.notify_all()
85                     if not self._due:
86                         self._returned = True
87                         return
88
89         thread = threading.Thread(target=in_thread)
90         thread.start()
91
92     def done(self):
93         with self._condition:
94             return self._returned
95
96     def event_with_tag(self, tag):
97         with self._condition:
98             while True:
99                 for event in self._events:
100                     if event.tag is tag:
101                         return event
102                 self._condition.wait()
103
104     def events(self):
105         with self._condition:
106             while not self._returned:
107                 self._condition.wait()
108             return tuple(self._events)
109
110
class ReadSomeButNotAllResponsesTest(unittest.TestCase):
    """Cython-level test of a client that reads only part of the responses."""

    def testReadSomeButNotAllResponses(self):
        """Cancel a call after reading one of the two messages the server sent.

        The server sends initial metadata, two one-byte messages and an OK
        status. The client receives initial metadata and only the first
        message, then cancels the call and closes the channel. The server is
        then shut down. The test checks that the call request and every
        server batch were accepted with ``CallError.ok``.
        """
        server_completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([(
            b'grpc.so_reuseport',
            0,
        )], False)
        server.register_completion_queue(server_completion_queue)
        # Bind to port 0; the returned port is used for the client's target.
        port = server.add_http2_port(b'[::]:0')
        server.start()
        channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(),
                                 None)

        server_shutdown_tag = 'server_shutdown_tag'
        server_driver = _ServerDriver(server_completion_queue,
                                      server_shutdown_tag)
        server_driver.start()

        client_condition = threading.Condition()
        client_due = set()

        server_call_condition = threading.Condition()
        server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
        server_send_first_message_tag = 'server_send_first_message_tag'
        server_send_second_message_tag = 'server_send_second_message_tag'
        server_complete_rpc_tag = 'server_complete_rpc_tag'
        server_call_due = set((
            server_send_initial_metadata_tag,
            server_send_first_message_tag,
            server_send_second_message_tag,
            server_complete_rpc_tag,
        ))
        server_call_completion_queue = cygrpc.CompletionQueue()
        server_call_driver = _QueueDriver(server_call_condition,
                                          server_call_completion_queue,
                                          server_call_due)
        server_call_driver.start()

        server_rpc_tag = 'server_rpc_tag'
        request_call_result = server.request_call(server_call_completion_queue,
                                                  server_completion_queue,
                                                  server_rpc_tag)

        client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
        client_complete_rpc_tag = 'client_complete_rpc_tag'
        client_call = channel.segregated_call(
            _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, (
                (
                    [
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                    ],
                    client_receive_initial_metadata_tag,
                ),
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            _EMPTY_METADATA, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_complete_rpc_tag,
                ),
            ))
        client_receive_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event)

        # Blocks until the server completion queue yields its first event.
        server_rpc_event = server_driver.first_event()

        # server_call_condition is also the _QueueDriver's condition, so
        # holding it keeps that thread from recording completions until both
        # batches have been started.
        with server_call_condition:
            server_send_initial_metadata_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
                                                        _EMPTY_FLAGS),
                ], server_send_initial_metadata_tag))
            server_send_first_message_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS),
                ], server_send_first_message_tag))
        # Wait for both completions before starting the next send.
        server_send_initial_metadata_event = server_call_driver.event_with_tag(
            server_send_initial_metadata_tag)
        server_send_first_message_event = server_call_driver.event_with_tag(
            server_send_first_message_tag)
        with server_call_condition:
            server_send_second_message_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS),
                ], server_send_second_message_tag))
            server_complete_rpc_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
                    cygrpc.SendStatusFromServerOperation(
                        (), cygrpc.StatusCode.ok, b'test details',
                        _EMPTY_FLAGS),
                ], server_complete_rpc_tag))
        server_send_second_message_event = server_call_driver.event_with_tag(
            server_send_second_message_tag)
        server_complete_rpc_event = server_call_driver.event_with_tag(
            server_complete_rpc_tag)
        # Blocks until every tag in server_call_due has completed.
        server_call_driver.events()

        client_receive_initial_metadata_event = client_receive_initial_metadata_event_future.result(
        )

        # Read only the first of the two messages the server sent.
        client_receive_first_message_tag = 'client_receive_first_message_tag'
        client_call.operate([
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
        ], client_receive_first_message_tag)
        client_receive_first_message_event = client_call.next_event()

        # Abandon the call with the second message still unread.
        client_call_cancel_result = client_call.cancel(
            cygrpc.StatusCode.cancelled, 'Cancelled during test!')
        client_complete_rpc_event = client_call.next_event()

        channel.close(cygrpc.StatusCode.unknown, 'Channel closed!')
        server.shutdown(server_completion_queue, server_shutdown_tag)
        server.cancel_all_calls()
        # Blocks until the shutdown tag is seen on the server queue.
        server_driver.events()

        self.assertEqual(cygrpc.CallError.ok, request_call_result)
        self.assertEqual(cygrpc.CallError.ok,
                         server_send_initial_metadata_start_batch_result)
        self.assertEqual(cygrpc.CallError.ok,
                         server_send_first_message_start_batch_result)
        self.assertEqual(cygrpc.CallError.ok,
                         server_send_second_message_start_batch_result)
        self.assertEqual(cygrpc.CallError.ok,
                         server_complete_rpc_start_batch_result)
        self.assertIs(server_rpc_tag, server_rpc_event.tag)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         server_rpc_event.completion_type)
        self.assertIsInstance(server_rpc_event.call, cygrpc.Call)
237
238
if __name__ == '__main__':
    # Run this module's tests with per-test output when executed directly.
    unittest.main(verbosity=2)