Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests_py3_only / interop / xds_interop_client.py
1 # Copyright 2020 The gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import argparse
16 import collections
17 import datetime
18 import logging
19 import signal
20 import threading
21 import time
22 import sys
23
24 from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple
25 import collections
26
27 from concurrent import futures
28
29 import grpc
30
31 from src.proto.grpc.testing import test_pb2
32 from src.proto.grpc.testing import test_pb2_grpc
33 from src.proto.grpc.testing import messages_pb2
34 from src.proto.grpc.testing import empty_pb2
35
# Root logger with a console handler; --log_file may add a file handler later.
logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# RPC methods this interop client knows how to drive.
_SUPPORTED_METHODS = (
    "UnaryCall",
    "EmptyCall",
)

# CamelCase method name -> CAPS_SNAKE key used in accumulated-stats responses.
_METHOD_CAMEL_TO_CAPS_SNAKE = {
    "UnaryCall": "UNARY_CALL",
    "EmptyCall": "EMPTY_CALL",
}

# CamelCase method name -> ClientConfigureRequest RPC-type enum value.
_METHOD_STR_TO_ENUM = {
    "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
    "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
}

# Inverse of _METHOD_STR_TO_ENUM: enum value -> CamelCase method name.
_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}

# Per-method metadata: method name -> sequence of (key, value) pairs.
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]

# How long a paused channel thread waits between checks for a config change.
_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
62
63
64 class _StatsWatcher:
65     _start: int
66     _end: int
67     _rpcs_needed: int
68     _rpcs_by_peer: DefaultDict[str, int]
69     _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]]
70     _no_remote_peer: int
71     _lock: threading.Lock
72     _condition: threading.Condition
73
74     def __init__(self, start: int, end: int):
75         self._start = start
76         self._end = end
77         self._rpcs_needed = end - start
78         self._rpcs_by_peer = collections.defaultdict(int)
79         self._rpcs_by_method = collections.defaultdict(
80             lambda: collections.defaultdict(int))
81         self._condition = threading.Condition()
82         self._no_remote_peer = 0
83
84     def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
85         """Records statistics for a single RPC."""
86         if self._start <= request_id < self._end:
87             with self._condition:
88                 if not peer:
89                     self._no_remote_peer += 1
90                 else:
91                     self._rpcs_by_peer[peer] += 1
92                     self._rpcs_by_method[method][peer] += 1
93                 self._rpcs_needed -= 1
94                 self._condition.notify()
95
96     def await_rpc_stats_response(
97             self, timeout_sec: int) -> messages_pb2.LoadBalancerStatsResponse:
98         """Blocks until a full response has been collected."""
99         with self._condition:
100             self._condition.wait_for(lambda: not self._rpcs_needed,
101                                      timeout=float(timeout_sec))
102             response = messages_pb2.LoadBalancerStatsResponse()
103             for peer, count in self._rpcs_by_peer.items():
104                 response.rpcs_by_peer[peer] = count
105             for method, count_by_peer in self._rpcs_by_method.items():
106                 for peer, count in count_by_peer.items():
107                     response.rpcs_by_method[method].rpcs_by_peer[peer] = count
108             response.num_failures = self._no_remote_peer + self._rpcs_needed
109         return response
110
111
# Guards all module-level mutable state below.
_global_lock = threading.Lock()
# Set by the SIGINT handler; tells every channel thread to wind down.
_stop_event = threading.Event()
# Monotonically increasing id handed to each RPC as it is started.
_global_rpc_id: int = 0
# Live _StatsWatchers; every completed RPC is reported to each of them.
_watchers: Set[_StatsWatcher] = set()
# The stats/configure gRPC server; assigned by _run().
_global_server = None
# Per-method counters of RPCs started/succeeded/failed since process start.
_global_rpcs_started: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_succeeded: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_failed: Mapping[str, int] = collections.defaultdict(int)

# Mapping[method, Mapping[status_code, count]]
_global_rpc_statuses: Mapping[str, Mapping[int, int]] = collections.defaultdict(
    lambda: collections.defaultdict(int))
124
125
def _handle_sigint(sig, frame) -> None:
    """SIGINT handler: stops channel threads and shuts down the stats server."""
    _stop_event.set()
    # The handler is installed before _run() creates the server, so the
    # server may still be None when the signal arrives.
    if _global_server is not None:
        _global_server.stop(None)
129
130
class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
                                ):
    """Reports how recently completed RPCs were distributed across peers."""

    def __init__(self):
        # Fixed: was `super(_LoadBalancerStatsServicer).__init__()`, which
        # creates an unbound super object and never runs the base initializer.
        super().__init__()

    def GetClientStats(
        self, request: messages_pb2.LoadBalancerStatsRequest,
        context: grpc.ServicerContext
    ) -> messages_pb2.LoadBalancerStatsResponse:
        """Watches the next request.num_rpcs RPCs and reports their peers."""
        logger.info("Received stats request.")
        start = None
        end = None
        watcher = None
        with _global_lock:
            # Watch the id range covering the next num_rpcs RPCs started.
            start = _global_rpc_id + 1
            end = start + request.num_rpcs
            watcher = _StatsWatcher(start, end)
            _watchers.add(watcher)
        response = watcher.await_rpc_stats_response(request.timeout_sec)
        with _global_lock:
            _watchers.remove(watcher)
        logger.info("Returning stats response: %s", response)
        return response

    def GetClientAccumulatedStats(
        self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
        context: grpc.ServicerContext
    ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
        """Reports cumulative per-method RPC counters since process start."""
        logger.info("Received cumulative stats request.")
        response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
        with _global_lock:
            for method in _SUPPORTED_METHODS:
                caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
                response.num_rpcs_started_by_method[
                    caps_method] = _global_rpcs_started[method]
                response.num_rpcs_succeeded_by_method[
                    caps_method] = _global_rpcs_succeeded[method]
                response.num_rpcs_failed_by_method[
                    caps_method] = _global_rpcs_failed[method]
                response.stats_per_method[
                    caps_method].rpcs_started = _global_rpcs_started[method]
                for code, count in _global_rpc_statuses[method].items():
                    response.stats_per_method[caps_method].result[code] = count
        logger.info("Returning cumulative stats response.")
        return response
177
178
179 def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]],
180                request_id: int, stub: test_pb2_grpc.TestServiceStub,
181                timeout: float, futures: Mapping[int, Tuple[grpc.Future,
182                                                            str]]) -> None:
183     logger.info(f"Sending {method} request to backend: {request_id}")
184     if method == "UnaryCall":
185         future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
186                                        metadata=metadata,
187                                        timeout=timeout)
188     elif method == "EmptyCall":
189         future = stub.EmptyCall.future(empty_pb2.Empty(),
190                                        metadata=metadata,
191                                        timeout=timeout)
192     else:
193         raise ValueError(f"Unrecognized method '{method}'.")
194     futures[request_id] = (future, method)
195
196
def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str,
                 print_response: bool) -> None:
    """Tallies the outcome of a completed RPC and notifies stats watchers."""
    exception = future.exception()
    hostname = ""
    # Fixed: this counter was previously updated without _global_lock even
    # though every other global counter in this module is lock-protected.
    with _global_lock:
        _global_rpc_statuses[method][future.code().value[0]] += 1
    if exception is not None:
        with _global_lock:
            _global_rpcs_failed[method] += 1
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error("RPC %s timed out", rpc_id)
        else:
            logger.error(exception)
    else:
        response = future.result()
        hostname = None
        # Prefer the hostname from initial metadata; fall back to the one
        # embedded in the response body.
        for metadatum in future.initial_metadata():
            if metadatum[0] == "hostname":
                hostname = metadatum[1]
                break
        else:
            hostname = response.hostname
        if future.code() == grpc.StatusCode.OK:
            with _global_lock:
                _global_rpcs_succeeded[method] += 1
        else:
            with _global_lock:
                _global_rpcs_failed[method] += 1
        if print_response:
            if future.code() == grpc.StatusCode.OK:
                logger.info("Successful response.")
            else:
                # Fixed: previously referenced an undefined name `call`,
                # raising NameError on this path.
                logger.info("RPC failed: %s", rpc_id)
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(rpc_id, hostname, method)
232
233
def _remove_completed_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]],
                           print_response: bool) -> None:
    """Reaps finished RPCs from *futures*, reporting each via _on_rpc_done."""
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, (future, method) in futures.items():
        if future.done():
            # Fixed: previously passed the global `args.print_response`
            # instead of the `print_response` parameter.
            _on_rpc_done(future_id, future, method, print_response)
            done.append(future_id)
    # Deletion is deferred so the dict is not mutated while being iterated.
    for rpc_id in done:
        del futures[rpc_id]
244
245
def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
    """Best-effort cancellation of every RPC still tracked in *futures*."""
    logger.info("Cancelling all remaining RPCs")
    for in_flight_future, _unused_method in futures.values():
        in_flight_future.cancel()
250
251
252 class _ChannelConfiguration:
253     """Configuration for a single client channel.
254
255     Instances of this class are meant to be dealt with as PODs. That is,
256     data member should be accessed directly. This class is not thread-safe.
257     When accessing any of its members, the lock member should be held.
258     """
259
260     def __init__(self, method: str, metadata: Sequence[Tuple[str,
261                                                              str]], qps: int,
262                  server: str, rpc_timeout_sec: int, print_response: bool):
263         # condition is signalled when a change is made to the config.
264         self.condition = threading.Condition()
265
266         self.method = method
267         self.metadata = metadata
268         self.qps = qps
269         self.server = server
270         self.rpc_timeout_sec = rpc_timeout_sec
271         self.print_response = print_response
272
273
def _run_single_channel(config: _ChannelConfiguration) -> None:
    """Drives RPCs over one channel at config.qps until _stop_event is set.

    Runs in its own thread. *config* is shared with the Configure service,
    which may change qps/metadata/timeout at any time (under config.condition).
    """
    global _global_rpc_id  # pylint: disable=global-statement
    with config.condition:
        server = config.server
    with grpc.insecure_channel(server) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)
        # request id -> (future, method) for RPCs still in flight.
        futures: Dict[int, Tuple[grpc.Future, str]] = {}
        while not _stop_event.is_set():
            with config.condition:
                if config.qps == 0:
                    # Paused: wait for a config change (or timeout), re-check.
                    config.condition.wait(
                        timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds())
                    continue
                else:
                    duration_per_query = 1.0 / float(config.qps)
            request_id = None
            with _global_lock:
                # Allocate a globally unique id and bump the started counter.
                request_id = _global_rpc_id
                _global_rpc_id += 1
                _global_rpcs_started[config.method] += 1
            start = time.time()
            end = start + duration_per_query
            with config.condition:
                _start_rpc(config.method, config.metadata, request_id, stub,
                           float(config.rpc_timeout_sec), futures)
            with config.condition:
                _remove_completed_rpcs(futures, config.print_response)
            logger.debug(f"Currently {len(futures)} in-flight RPCs")
            # Sleep out the remainder of this query's time slice to hold QPS.
            now = time.time()
            while now < end:
                time.sleep(end - now)
                now = time.time()
        _cancel_all_rpcs(futures)
307
308
class _XdsUpdateClientConfigureServicer(
        test_pb2_grpc.XdsUpdateClientConfigureServiceServicer):
    """Lets the test driver reconfigure which methods this client exercises."""

    def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration],
                 qps: int):
        # Fixed: was `super(_XdsUpdateClientConfigureServicer).__init__()`,
        # which creates an unbound super object and skips the base initializer.
        super().__init__()
        self._per_method_configs = per_method_configs
        self._qps = qps

    def Configure(
            self, request: messages_pb2.ClientConfigureRequest,
            context: grpc.ServicerContext
    ) -> messages_pb2.ClientConfigureResponse:
        """Enables the requested methods at full QPS and pauses the rest."""
        logger.info("Received Configure RPC: %s", request)
        # Fixed: this was a generator, which the first `in` membership test
        # partially or fully consumed, so later loop iterations saw an
        # exhausted iterator and wrongly paused still-requested methods.
        method_strs = {_METHOD_ENUM_TO_STR[t] for t in request.types}
        for method in _SUPPORTED_METHODS:
            method_enum = _METHOD_STR_TO_ENUM[method]
            channel_config = self._per_method_configs[method]
            if method in method_strs:
                qps = self._qps
                metadata = ((md.key, md.value)
                            for md in request.metadata
                            if md.type == method_enum)
                # For backward compatibility, do not change timeout when we
                # receive a default value timeout.
                if request.timeout_sec == 0:
                    timeout_sec = channel_config.rpc_timeout_sec
                else:
                    timeout_sec = request.timeout_sec
            else:
                qps = 0
                metadata = ()
                # Leave timeout unchanged for backward compatibility.
                timeout_sec = channel_config.rpc_timeout_sec
            with channel_config.condition:
                channel_config.qps = qps
                channel_config.metadata = list(metadata)
                channel_config.rpc_timeout_sec = timeout_sec
                channel_config.condition.notify_all()
        return messages_pb2.ClientConfigureResponse()
349
350
351 class _MethodHandle:
352     """An object grouping together threads driving RPCs for a method."""
353
354     _channel_threads: List[threading.Thread]
355
356     def __init__(self, num_channels: int,
357                  channel_config: _ChannelConfiguration):
358         """Creates and starts a group of threads running the indicated method."""
359         self._channel_threads = []
360         for i in range(num_channels):
361             thread = threading.Thread(target=_run_single_channel,
362                                       args=(channel_config,))
363             thread.start()
364             self._channel_threads.append(thread)
365
366     def stop(self) -> None:
367         """Joins all threads referenced by the handle."""
368         for channel_thread in self._channel_threads:
369             channel_thread.join()
370
371
def _run(args: argparse.Namespace, methods: Sequence[str],
         per_method_metadata: PerMethodMetadataType) -> None:
    """Starts channel worker threads and serves the stats/configure services.

    Blocks until the server terminates (see _handle_sigint), then joins the
    worker threads.
    """
    logger.info("Starting python xDS Interop Client.")
    global _global_server  # pylint: disable=global-statement
    method_handles = []
    channel_configs = {}
    for method in _SUPPORTED_METHODS:
        # Methods not selected via --rpc start paused (qps == 0); the
        # Configure service may enable them later.
        if method in methods:
            qps = args.qps
        else:
            qps = 0
        channel_config = _ChannelConfiguration(
            method, per_method_metadata.get(method, []), qps, args.server,
            args.rpc_timeout_sec, args.print_response)
        channel_configs[method] = channel_config
        method_handles.append(_MethodHandle(args.num_channels, channel_config))
    _global_server = grpc.server(futures.ThreadPoolExecutor())
    _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        _LoadBalancerStatsServicer(), _global_server)
    test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
        _XdsUpdateClientConfigureServicer(channel_configs, args.qps),
        _global_server)
    _global_server.start()
    # Runs until _handle_sigint stops the server.
    _global_server.wait_for_termination()
    for method_handle in method_handles:
        method_handle.stop()
399
400
def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType:
    """Parses the --metadata flag into per-method (key, value) pairs.

    Args:
        metadata_arg: comma-delimited "METHOD:KEY:VALUE" triples; may be "".

    Returns:
        A defaultdict mapping method name to a list of (key, value) tuples.

    Raises:
        ValueError: on a malformed triple or an unsupported method name.
    """
    # Fixed: previously read the global `args.metadata` instead of the
    # parameter, so the function only worked from this script's __main__.
    metadata = metadata_arg.split(",") if metadata_arg else []
    per_method_metadata = collections.defaultdict(list)
    for metadatum in metadata:
        elems = metadatum.split(":")
        if len(elems) != 3:
            raise ValueError(
                f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
        if elems[0] not in _SUPPORTED_METHODS:
            raise ValueError(f"Unrecognized method '{elems[0]}'")
        per_method_metadata[elems[0]].append((elems[1], elems[2]))
    return per_method_metadata
413
414
def parse_rpc_arg(rpc_arg: str) -> Sequence[str]:
    """Splits the --rpc flag into method names, rejecting unsupported ones."""
    methods = rpc_arg.split(",")
    unsupported = set(methods) - set(_SUPPORTED_METHODS)
    if unsupported:
        raise ValueError("--rpc supported methods: {}".format(
            ", ".join(_SUPPORTED_METHODS)))
    return methods
421
422
if __name__ == "__main__":
    # Command-line interface for the xDS interop client.
    parser = argparse.ArgumentParser(
        description='Run Python XDS interop client.')
    parser.add_argument(
        "--num_channels",
        default=1,
        type=int,
        help="The number of channels from which to send requests.")
    parser.add_argument("--print_response",
                        default=False,
                        action="store_true",
                        help="Write RPC response to STDOUT.")
    parser.add_argument(
        "--qps",
        default=1,
        type=int,
        help="The number of queries to send from each channel per second.")
    parser.add_argument("--rpc_timeout_sec",
                        default=30,
                        type=int,
                        help="The per-RPC timeout in seconds.")
    parser.add_argument("--server",
                        default="localhost:50051",
                        help="The address of the server.")
    parser.add_argument(
        "--stats_port",
        default=50052,
        type=int,
        help="The port on which to expose the peer distribution stats service.")
    parser.add_argument('--verbose',
                        help='verbose log output',
                        default=False,
                        action='store_true')
    parser.add_argument("--log_file",
                        default=None,
                        type=str,
                        help="A file to log to.")
    rpc_help = "A comma-delimited list of RPC methods to run. Must be one of "
    rpc_help += ", ".join(_SUPPORTED_METHODS)
    rpc_help += "."
    parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help)
    metadata_help = (
        "A comma-delimited list of 3-tuples of the form " +
        "METHOD:KEY:VALUE, e.g. " +
        "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3")
    parser.add_argument("--metadata", default="", type=str, help=metadata_help)
    args = parser.parse_args()
    # Install the SIGINT handler before any worker threads exist so Ctrl-C
    # cleanly stops both the server and the channel threads.
    signal.signal(signal.SIGINT, _handle_sigint)
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.log_file:
        # Mirror log output to a file in addition to the console handler.
        file_handler = logging.FileHandler(args.log_file, mode='a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata))