Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / _cython / _common.py
1 # Copyright 2017 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Common utilities for tests of the Cython layer of gRPC Python."""
15
16 import collections
17 import threading
18
19 from grpc._cython import cygrpc
20
# Number of times execute_many_times() invokes the behavior it is given.
RPC_COUNT = 4000

# Zero-valued flags constant. It is not referenced in the visible part of this
# module; presumably importing tests pass it as an operation "flags" argument
# -- confirm at the call sites.
EMPTY_FLAGS = 0

# Metadata fixtures. Each is a tuple of (key, value) pairs: one text pair and
# one pair whose key ends in '-bin' and whose value is a 6000-byte bytes object
# (a two-byte pattern repeated 3000 times).
#
# NOTE(review): the first client pair uses the key text as its value, unlike
# the '...-value' pattern used by the server fixtures below. It is left
# untouched because the literal is runtime data.
INVOCATION_METADATA = (
    ('client-md-key', 'client-md-key'),
    ('client-md-key-bin', b'\x00\x01' * 3000),
)

# Server-side fixture with the same (text pair, '-bin' pair) shape.
INITIAL_METADATA = (
    ('server-initial-md-key', 'server-initial-md-value'),
    ('server-initial-md-key-bin', b'\x00\x02' * 3000),
)

# Server-side fixture with the same (text pair, '-bin' pair) shape.
TRAILING_METADATA = (
    ('server-trailing-md-key', 'server-trailing-md-value'),
    ('server-trailing-md-key-bin', b'\x00\x03' * 3000),
)
39
40
class QueueDriver(object):
    """Polls a completion queue on a helper thread and hands out events by tag.

    Callers announce the tags they expect with add_due() and later collect
    the matching events, in arrival order, with event_with_tag().

    All bookkeeping done by the polling thread happens while holding the
    condition passed to the constructor.
    """

    def __init__(self, condition, completion_queue):
        """Initializes the driver.

        Args:
          condition: A threading.Condition (or equivalent offering the context
            manager protocol, wait() and notify_all()) guarding this driver.
          completion_queue: An object whose poll() method returns the next
            event; every event must expose a `tag` attribute.
        """
        self._condition = condition
        self._completion_queue = completion_queue
        # tag -> number of events still expected for that tag.
        self._due = collections.defaultdict(int)
        # tag -> FIFO backlog of events received but not yet claimed. A deque
        # makes taking the oldest event O(1); list.pop(0) shifts the whole
        # backlog on every removal.
        self._events = collections.defaultdict(collections.deque)

    def add_due(self, tags):
        """Registers one more expected event for each tag in `tags`.

        If nothing is currently due, a polling thread is started first. That
        thread keeps polling until every due tag has been satisfied and then
        exits.

        The thread is started before the counters are incremented and only
        touches the counters while holding the condition. Calling this method
        with the condition held therefore keeps the thread from observing the
        counters before they are updated.

        Args:
          tags: An iterable of tags; a tag that appears k times is expected
            k more times.
        """
        if not self._due:

            def in_thread():
                while True:
                    event = self._completion_queue.poll()
                    with self._condition:
                        self._events[event.tag].append(event)
                        self._due[event.tag] -= 1
                        self._condition.notify_all()
                        if self._due[event.tag] <= 0:
                            self._due.pop(event.tag)
                            if not self._due:
                                return

            thread = threading.Thread(target=in_thread)
            thread.start()
        for tag in tags:
            self._due[tag] += 1

    def event_with_tag(self, tag):
        """Blocks until an event with `tag` is available and returns it.

        Events carrying the same tag are returned in the order they were
        polled.

        Args:
          tag: The tag of the event to wait for.

        Returns:
          The oldest unclaimed event whose tag equals `tag`.
        """
        with self._condition:
            # Re-check after every wake-up: notify_all() is issued for every
            # polled event, whatever its tag.
            while not self._events[tag]:
                self._condition.wait()
            return self._events[tag].popleft()
76
77
def execute_many_times(behavior):
    """Calls `behavior` RPC_COUNT times and returns the results as a tuple.

    Args:
      behavior: A callable taking no arguments.

    Returns:
      A tuple holding the return value of each call, in call order.
    """
    outcomes = []
    for _ in range(RPC_COUNT):
        outcomes.append(behavior())
    return tuple(outcomes)
80
81
class OperationResult(
        collections.namedtuple(
            'OperationResult',
            ['start_batch_result', 'completion_type', 'success'])):
    """Three-field record describing the outcome of one operation batch.

    Fields, in order: start_batch_result, completion_type, success.
    """
89
90
# The OperationResult of a batch that started with CallError.ok and finished
# with a successful operation_complete event.
SUCCESSFUL_OPERATION_RESULT = OperationResult(
    start_batch_result=cygrpc.CallError.ok,
    completion_type=cygrpc.CompletionType.operation_complete,
    success=True)
93
94
class RpcTest(object):
    """Fixture that wires a cygrpc server and channel to QueueDrivers.

    It defines setUp()/tearDown() only; presumably concrete tests combine it
    with unittest.TestCase -- confirm in the test modules that use it.
    """

    def setUp(self):
        """Starts a server on an ephemeral port and connects a channel to it.

        Afterwards the instance exposes `server`, `channel`, and for each side
        (server/client) a completion queue, a condition and a QueueDriver.
        """
        # Server: completion queue first, then the server registered on it.
        self.server_completion_queue = cygrpc.CompletionQueue()
        server_arguments = [(b'grpc.so_reuseport', 0)]
        self.server = cygrpc.Server(server_arguments, False)
        self.server.register_completion_queue(self.server_completion_queue)
        # Port 0 in the address; the port actually bound is returned.
        bound_port = self.server.add_http2_port(b'[::]:0')
        self.server.start()

        target = 'localhost:{}'.format(bound_port).encode()
        self.channel = cygrpc.Channel(target, [], None)

        # The shutdown tag is registered as due right away, so the server
        # driver has an outstanding tag until tearDown() requests shutdown.
        self._server_shutdown_tag = 'server_shutdown_tag'
        self.server_condition = threading.Condition()
        self.server_driver = QueueDriver(
            self.server_condition, self.server_completion_queue)
        shutdown_tags = {self._server_shutdown_tag}
        with self.server_condition:
            self.server_driver.add_due(shutdown_tags)

        # Client: its own condition, completion queue and driver.
        self.client_condition = threading.Condition()
        self.client_completion_queue = cygrpc.CompletionQueue()
        self.client_driver = QueueDriver(
            self.client_condition, self.client_completion_queue)

    def tearDown(self):
        """Requests server shutdown, then cancels all of the server's calls.

        The shutdown is requested on the server completion queue with the tag
        that setUp() registered as due.
        """
        shutdown_queue = self.server_completion_queue
        self.server.shutdown(shutdown_queue, self._server_shutdown_tag)
        self.server.cancel_all_calls()