Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / _cython / _cancel_many_calls_test.py
1 # Copyright 2016 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Test making many calls and immediately cancelling most of them."""
15
16 import threading
17 import unittest
18
19 from grpc._cython import cygrpc
20 from grpc.framework.foundation import logging_pool
21 from tests.unit.framework.common import test_constants
22 from tests.unit._cython import test_utilities
23
# Placeholder flags and metadata passed to every call and operation in this
# module.
_EMPTY_FLAGS = 0
_EMPTY_METADATA = ()

# Tags attached to server-side completion-queue operations so that the
# resulting events can be told apart.
_SERVER_SHUTDOWN_TAG = 'server_shutdown'
_REQUEST_CALL_TAG = 'request_call'
_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server'
_RECEIVE_MESSAGE_TAG = 'receive_message'
_SERVER_COMPLETE_CALL_TAG = 'server_complete_call'

# The test waits for this fraction of the concurrent RPCs to complete before
# it cancels every call; the remainder is the number of further call events
# it drains afterwards.
_SUCCESS_CALL_FRACTION = 1.0 / 8.0
_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS
36
37
class _State(object):
    """Bookkeeping shared by the test, the serving thread and the handlers.

    Attributes:
      condition: Condition guarding every other attribute; all reads and
        writes of the fields below in this module happen while holding it.
      handlers_released: Set to True by the test once handlers may proceed.
      parked_handlers: Number of handlers that have started and begun waiting.
      handled_rpcs: Number of RPCs the serving thread has handed to the pool.
    """

    def __init__(self):
        # Counters start at zero and the release flag starts cleared.
        self.handled_rpcs = 0
        self.parked_handlers = 0
        self.handlers_released = False
        self.condition = threading.Condition()
45
46
def _is_cancellation_event(event):
    """Tell whether an event reports that the RPC was cancelled.

    Args:
      event: A completion-queue event exposing ``tag`` and
        ``batch_operations``.

    Returns:
      False when the event's tag is not the receive-close-on-server tag;
      otherwise whatever ``cancelled()`` returns on the event's first batch
      operation.
    """
    # Compare tags by value rather than with ``is``: identity of equal
    # strings is an implementation detail and must not decide the outcome.
    if event.tag != _RECEIVE_CLOSE_ON_SERVER_TAG:
        return False
    return event.batch_operations[0].cancelled()
50
51
class _Handler(object):
    """Callable that services a single RPC on a thread-pool worker.

    Args:
      state: The shared _State used to park and release handlers.
      completion_queue: The completion queue dedicated to this RPC.
      rpc_event: The request-call event; its ``call`` attribute is the
        server-side call object to operate on.
    """

    def __init__(self, state, completion_queue, rpc_event):
        self._call = rpc_event.call
        self._completion_queue = completion_queue
        self._lock = threading.Lock()
        self._state = state

    def _park_until_released(self):
        """Register as parked and block until the test releases handlers."""
        shared = self._state
        with shared.condition:
            shared.parked_handlers += 1
            # Wake the waiters once the number of parked handlers reaches
            # THREAD_CONCURRENCY.
            if shared.parked_handlers == test_constants.THREAD_CONCURRENCY:
                shared.condition.notify_all()
            while not shared.handlers_released:
                shared.condition.wait()

    def _start_receives(self):
        """Start the receive-close and receive-message batches."""
        close_batch = (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),)
        message_batch = (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),)
        with self._lock:
            self._call.start_server_batch(close_batch,
                                          _RECEIVE_CLOSE_ON_SERVER_TAG)
            self._call.start_server_batch(message_batch, _RECEIVE_MESSAGE_TAG)

    def _send_response(self):
        """Start the batch that sends metadata, one message and OK status."""
        with self._lock:
            response_batch = (
                cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
                                                    _EMPTY_FLAGS),
                cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS),
                cygrpc.SendStatusFromServerOperation(
                    _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
                    _EMPTY_FLAGS),
            )
            self._call.start_server_batch(response_batch,
                                          _SERVER_COMPLETE_CALL_TAG)

    def __call__(self):
        """Park, then either drain a cancelled RPC or answer it."""
        self._park_until_released()
        self._start_receives()

        if _is_cancellation_event(self._completion_queue.poll()):
            # Two batches were started; consume the event of the other one.
            self._completion_queue.poll()
            return

        self._send_response()
        # Three batches have been started and one event consumed so far;
        # consume the remaining two.
        for _ in range(2):
            self._completion_queue.poll()
92
93
def _serve(state, server, server_completion_queue, thread_pool):
    """Accept RPC_CONCURRENCY calls and hand each one to the thread pool.

    Args:
      state: The shared _State; ``handled_rpcs`` is incremented per call.
      server: The server on which calls are requested.
      server_completion_queue: Queue on which request-call events arrive.
      thread_pool: Pool whose ``submit`` receives one _Handler per call.
    """
    expected_rpcs = test_constants.RPC_CONCURRENCY
    for _ in range(expected_rpcs):
        # Each RPC gets its own completion queue for its call-level events.
        per_call_queue = cygrpc.CompletionQueue()
        server.request_call(per_call_queue, server_completion_queue,
                            _REQUEST_CALL_TAG)
        request_event = server_completion_queue.poll()
        handler = _Handler(state, per_call_queue, request_event)
        thread_pool.submit(handler)
        with state.condition:
            state.handled_rpcs += 1
            # Wake waiters once every expected RPC has been dispatched.
            if state.handled_rpcs >= expected_rpcs:
                state.condition.notify_all()
    # Consume one more event from the server queue -- presumably the
    # notification for the shutdown the test requests on this queue.
    server_completion_queue.poll()
106
107
class _QueueDriver(object):
    """Polls a completion queue on its own thread until no tag is due.

    Args:
      condition: Condition guarding the collected events and the due tags.
      completion_queue: Object whose ``poll()`` yields events with a ``tag``.
      due: Mutable collection of tags still expected; each polled event's
        tag is removed from it.
    """

    def __init__(self, condition, completion_queue, due):
        self._condition = condition
        self._completion_queue = completion_queue
        self._due = due
        self._events = []
        self._returned = False

    def _drain(self):
        """Poll events until the collection of due tags is empty."""
        draining = True
        while draining:
            polled = self._completion_queue.poll()
            with self._condition:
                self._events.append(polled)
                self._due.remove(polled.tag)
                self._condition.notify_all()
                if not self._due:
                    # Nothing left to wait for: record it and stop polling.
                    self._returned = True
                    draining = False

    def start(self):
        """Launch the polling thread."""
        threading.Thread(target=self._drain).start()

    def events(self, at_least):
        """Block until at least ``at_least`` events exist; return them all.

        Returns:
          A tuple snapshot of every event collected so far.
        """
        with self._condition:
            while at_least > len(self._events):
                self._condition.wait()
            return tuple(self._events)
138
139
class CancelManyCallsTest(unittest.TestCase):
    """Starts many concurrent RPCs and cancels most of them."""

    def testCancelManyCalls(self):
        """Let _SUCCESSFUL_CALLS RPCs finish, then cancel every call.

        The test passes if it runs to completion: it makes no assertions and
        only requires that every expected call event is delivered.
        """
        server_thread_pool = logging_pool.pool(
            test_constants.THREAD_CONCURRENCY)

        server_completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        server.register_completion_queue(server_completion_queue)
        port = server.add_http2_port(b'[::]:0')
        server.start()
        channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None,
                                 None)

        state = _State()

        server_thread_args = (
            state,
            server,
            server_completion_queue,
            server_thread_pool,
        )
        server_thread = threading.Thread(target=_serve, args=server_thread_args)
        server_thread.start()

        # Start every client call up front; each one sends its request and
        # asks for the full response in a single batch with a unique tag.
        client_calls = []
        for index in range(test_constants.RPC_CONCURRENCY):
            tag = 'client_complete_call_{0:04d}_tag'.format(index)
            client_call = channel.integrated_call(
                _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA,
                None, ((
                    (
                        cygrpc.SendInitialMetadataOperation(
                            _EMPTY_METADATA, _EMPTY_FLAGS),
                        cygrpc.SendMessageOperation(b'\x45\x56',
                                                    _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(
                            _EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ),
                    tag,
                ),))
            client_calls.append(client_call)

        # Collect the first _SUCCESSFUL_CALLS call events; presumably
        # SimpleFuture evaluates the callable off this thread -- confirm
        # against test_utilities.
        client_events_future = test_utilities.SimpleFuture(lambda: tuple(
            channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS)))

        # Hold the handlers back until THREAD_CONCURRENCY of them are parked
        # and the serving thread has dispatched every RPC, then release them.
        with state.condition:
            while (state.parked_handlers < test_constants.THREAD_CONCURRENCY or
                   state.handled_rpcs < test_constants.RPC_CONCURRENCY):
                state.condition.wait()
            state.handlers_released = True
            state.condition.notify_all()

        # Once the expected number of calls has completed, cancel all calls
        # and drain the events of the remaining ones.
        client_events_future.result()
        for client_call in client_calls:
            client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!')
        for _ in range(_UNSUCCESSFUL_CALLS):
            channel.next_call_event()

        channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!')
        # NOTE(review): shutdown is requested while holding state.condition,
        # as before; the reason is not evident from this module.
        with state.condition:
            server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
220
221
# Run the tests with verbose output when this module is executed directly.
if __name__ == '__main__':
    unittest.main(verbosity=2)