Imported Upstream version 1.21.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / stress / client.py
1 # Copyright 2016 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Entry point for running stress tests."""
15
16 import argparse
17 from concurrent import futures
18 import threading
19
20 import grpc
21 from six.moves import queue
22 from src.proto.grpc.testing import metrics_pb2_grpc
23 from src.proto.grpc.testing import test_pb2_grpc
24
25 from tests.interop import methods
26 from tests.interop import resources
27 from tests.qps import histogram
28 from tests.stress import metrics_server
29 from tests.stress import test_runner
30
31
32 def _args():
33     parser = argparse.ArgumentParser(
34         description='gRPC Python stress test client')
35     parser.add_argument(
36         '--server_addresses',
37         help='comma separated list of hostname:port to run servers on',
38         default='localhost:8080',
39         type=str)
40     parser.add_argument(
41         '--test_cases',
42         help='comma separated list of testcase:weighting of tests to run',
43         default='large_unary:100',
44         type=str)
45     parser.add_argument(
46         '--test_duration_secs',
47         help='number of seconds to run the stress test',
48         default=-1,
49         type=int)
50     parser.add_argument(
51         '--num_channels_per_server',
52         help='number of channels per server',
53         default=1,
54         type=int)
55     parser.add_argument(
56         '--num_stubs_per_channel',
57         help='number of stubs to create per channel',
58         default=1,
59         type=int)
60     parser.add_argument(
61         '--metrics_port',
62         help='the port to listen for metrics requests on',
63         default=8081,
64         type=int)
65     parser.add_argument(
66         '--use_test_ca',
67         help='Whether to use our fake CA. Requires --use_tls=true',
68         default=False,
69         type=bool)
70     parser.add_argument(
71         '--use_tls', help='Whether to use TLS', default=False, type=bool)
72     parser.add_argument(
73         '--server_host_override',
74         help='the server host to which to claim to connect',
75         type=str)
76     return parser.parse_args()
77
78
79 def _test_case_from_arg(test_case_arg):
80     for test_case in methods.TestCase:
81         if test_case_arg == test_case.value:
82             return test_case
83     else:
84         raise ValueError('No test case {}!'.format(test_case_arg))
85
86
87 def _parse_weighted_test_cases(test_case_args):
88     weighted_test_cases = {}
89     for test_case_arg in test_case_args.split(','):
90         name, weight = test_case_arg.split(':', 1)
91         test_case = _test_case_from_arg(name)
92         weighted_test_cases[test_case] = int(weight)
93     return weighted_test_cases
94
95
96 def _get_channel(target, args):
97     if args.use_tls:
98         if args.use_test_ca:
99             root_certificates = resources.test_root_certificates()
100         else:
101             root_certificates = None  # will load default roots.
102         channel_credentials = grpc.ssl_channel_credentials(
103             root_certificates=root_certificates)
104         options = ((
105             'grpc.ssl_target_name_override',
106             args.server_host_override,
107         ),)
108         channel = grpc.secure_channel(
109             target, channel_credentials, options=options)
110     else:
111         channel = grpc.insecure_channel(target)
112
113     # waits for the channel to be ready before we start sending messages
114     grpc.channel_ready_future(channel).result()
115     return channel
116
117
118 def run_test(args):
119     test_cases = _parse_weighted_test_cases(args.test_cases)
120     test_server_targets = args.server_addresses.split(',')
121     # Propagate any client exceptions with a queue
122     exception_queue = queue.Queue()
123     stop_event = threading.Event()
124     hist = histogram.Histogram(1, 1)
125     runners = []
126
127     server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
128     metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
129         metrics_server.MetricsServer(hist), server)
130     server.add_insecure_port('[::]:{}'.format(args.metrics_port))
131     server.start()
132
133     for test_server_target in test_server_targets:
134         for _ in range(args.num_channels_per_server):
135             channel = _get_channel(test_server_target, args)
136             for _ in range(args.num_stubs_per_channel):
137                 stub = test_pb2_grpc.TestServiceStub(channel)
138                 runner = test_runner.TestRunner(stub, test_cases, hist,
139                                                 exception_queue, stop_event)
140                 runners.append(runner)
141
142     for runner in runners:
143         runner.start()
144     try:
145         timeout_secs = args.test_duration_secs
146         if timeout_secs < 0:
147             timeout_secs = None
148         raise exception_queue.get(block=True, timeout=timeout_secs)
149     except queue.Empty:
150         # No exceptions thrown, success
151         pass
152     finally:
153         stop_event.set()
154         for runner in runners:
155             runner.join()
156         runner = None
157         server.stop(None)
158
159
160 if __name__ == '__main__':
161     run_test(_args())