1 # Copyright 2020 The gRPC authors.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
24 from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple
27 from concurrent import futures
31 from src.proto.grpc.testing import test_pb2
32 from src.proto.grpc.testing import test_pb2_grpc
33 from src.proto.grpc.testing import messages_pb2
34 from src.proto.grpc.testing import empty_pb2
# Console logging for the whole process: attach a stderr StreamHandler with a
# timestamped, level-tagged format to the root logger. The `formatter` object
# is reused later for the optional --log_file handler.
logger = logging.getLogger()
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# The CamelCase names of the RPC methods this client can drive. (The tuple's
# members and closing paren are elided in this listing; downstream code uses
# "UnaryCall" and "EmptyCall" — confirm against the full source.)
42 _SUPPORTED_METHODS = (
# Maps CamelCase method names to the CAPS_SNAKE keys used in the
# LoadBalancerAccumulatedStatsResponse maps.
47 _METHOD_CAMEL_TO_CAPS_SNAKE = {
48 "UnaryCall": "UNARY_CALL",
49 "EmptyCall": "EMPTY_CALL",
# Maps CamelCase method names to ClientConfigureRequest RPC-type enum values.
52 _METHOD_STR_TO_ENUM = {
53 "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
54 "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
# Inverse of _METHOD_STR_TO_ENUM: enum value -> CamelCase method name.
57 _METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}
# Type alias: method name -> sequence of (key, value) metadata pairs.
59 PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
# How long _run_single_channel waits on its config condition before
# re-checking for shutdown or a new QPS setting.
61 _CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
# _StatsWatcher attribute declarations (the class header is elided in this
# listing). Counters below are intended to be accessed while holding
# _condition — confirm against the elided lines.
# peer hostname -> number of completed RPCs attributed to that peer.
68 _rpcs_by_peer: DefaultDict[str, int]
# method name -> (peer hostname -> completed-RPC count).
69 _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]]
# Signalled by on_rpc_complete each time an in-range RPC finishes.
72 _condition: threading.Condition
74 def __init__(self, start: int, end: int):
# Watches the RPCs whose global ids fall in [start, end). on_rpc_complete
# reads self._start and self._end; their assignments (original lines 75-76)
# are elided from this listing.
# _rpcs_needed counts how many watched RPCs have not yet completed.
77 self._rpcs_needed = end - start
78 self._rpcs_by_peer = collections.defaultdict(int)
79 self._rpcs_by_method = collections.defaultdict(
80 lambda: collections.defaultdict(int))
81 self._condition = threading.Condition()
# Number of completed RPCs that never reported a peer hostname.
82 self._no_remote_peer = 0
84 def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
85 """Records statistics for a single RPC."""
# Only RPCs inside this watcher's id window are counted.
86 if self._start <= request_id < self._end:
# NOTE(review): lines are elided here — presumably `with self._condition:`
# and a check on `peer`; notify() below must run while holding the lock.
# An RPC with no reported peer is tallied separately as _no_remote_peer.
89 self._no_remote_peer += 1
91 self._rpcs_by_peer[peer] += 1
92 self._rpcs_by_method[method][peer] += 1
# One fewer RPC outstanding; wake await_rpc_stats_response if it is waiting.
93 self._rpcs_needed -= 1
94 self._condition.notify()
96 def await_rpc_stats_response(
97 self, timeout_sec: int) -> messages_pb2.LoadBalancerStatsResponse:
98 """Blocks until a full response has been collected."""
# NOTE(review): a `with self._condition:` line appears to be elided here —
# wait_for() must be called while holding the condition's lock.
# Wait until every watched RPC has completed or the deadline passes.
100 self._condition.wait_for(lambda: not self._rpcs_needed,
101 timeout=float(timeout_sec))
102 response = messages_pb2.LoadBalancerStatsResponse()
# Copy the per-peer and per-method/per-peer tallies into the proto response.
103 for peer, count in self._rpcs_by_peer.items():
104 response.rpcs_by_peer[peer] = count
105 for method, count_by_peer in self._rpcs_by_method.items():
106 for peer, count in count_by_peer.items():
107 response.rpcs_by_method[method].rpcs_by_peer[peer] = count
# RPCs that never reported a peer, plus RPCs still outstanding at the
# deadline, both count as failures.
108 response.num_failures = self._no_remote_peer + self._rpcs_needed
# NOTE(review): the `return response` line is elided from this listing.
# Process-wide bookkeeping shared by the servicers and the channel threads.
112 _global_lock = threading.Lock()
# Set on shutdown; _run_single_channel loops until this is set.
113 _stop_event = threading.Event()
# Id of the most recently started RPC; monotonically increasing.
114 _global_rpc_id: int = 0
# Active _StatsWatcher instances notified on every RPC completion.
115 _watchers: Set[_StatsWatcher] = set()
# The gRPC server hosting the stats/configure services; assigned in _run().
116 _global_server = None
# These counters are mutated with `+=` elsewhere in the file, so annotate
# them as DefaultDict rather than the read-only Mapping.
117 _global_rpcs_started: DefaultDict[str, int] = collections.defaultdict(int)
118 _global_rpcs_succeeded: DefaultDict[str, int] = collections.defaultdict(int)
119 _global_rpcs_failed: DefaultDict[str, int] = collections.defaultdict(int)
121 # Mapping[method, Mapping[status_code, count]]
122 _global_rpc_statuses: DefaultDict[str, DefaultDict[int, int]] = collections.defaultdict(
123 lambda: collections.defaultdict(int))
126 def _handle_sigint(sig, frame) -> None:
# SIGINT handler: stop the stats server immediately (grace period None).
# NOTE(review): a line is elided here — presumably `_stop_event.set()` so the
# channel threads also exit; confirm against the full source.
128 _global_server.stop(None)
# Serves LoadBalancerStats requests by snapshotting per-peer RPC counts.
131 class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
# NOTE(review): super(_LoadBalancerStatsServicer).__init__() is missing its
# second argument; it should be super().__init__() (or pass self explicitly).
# Harmless as written — the base __init__ is simply never called.
135 super(_LoadBalancerStatsServicer).__init__()
# GetClientStats (the def line is elided): registers a _StatsWatcher covering
# the next request.num_rpcs RPC ids, waits for them, and returns the stats.
138 self, request: messages_pb2.LoadBalancerStatsRequest,
139 context: grpc.ServicerContext
140 ) -> messages_pb2.LoadBalancerStatsResponse:
141 logger.info("Received stats request.")
# Watch the window of RPC ids that will be issued after this request.
146 start = _global_rpc_id + 1
147 end = start + request.num_rpcs
148 watcher = _StatsWatcher(start, end)
# NOTE(review): elided lines here likely take _global_lock around the
# watcher registration/removal — confirm.
149 _watchers.add(watcher)
150 response = watcher.await_rpc_stats_response(request.timeout_sec)
152 _watchers.remove(watcher)
153 logger.info("Returning stats response: %s", response)
156 def GetClientAccumulatedStats(
157 self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
158 context: grpc.ServicerContext
159 ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
# Returns lifetime started/succeeded/failed counts plus per-status-code
# tallies for every supported method, keyed by the CAPS_SNAKE method name.
160 logger.info("Received cumulative stats request.")
161 response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
# NOTE(review): an elided line here presumably acquires _global_lock while
# the global counters are read — confirm.
163 for method in _SUPPORTED_METHODS:
164 caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
# Flat per-method counters.
165 response.num_rpcs_started_by_method[
166 caps_method] = _global_rpcs_started[method]
167 response.num_rpcs_succeeded_by_method[
168 caps_method] = _global_rpcs_succeeded[method]
169 response.num_rpcs_failed_by_method[
170 caps_method] = _global_rpcs_failed[method]
# Nested per-method stats: started count plus a status-code -> count map.
171 response.stats_per_method[
172 caps_method].rpcs_started = _global_rpcs_started[method]
173 for code, count in _global_rpc_statuses[method].items():
174 response.stats_per_method[caps_method].result[code] = count
175 logger.info("Returning cumulative stats response.")
# NOTE(review): the `return response` line is elided from this listing.
179 def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]],
180 request_id: int, stub: test_pb2_grpc.TestServiceStub,
181 timeout: float, futures: Mapping[int, Tuple[grpc.Future,
# Fires one asynchronous RPC of the given kind and records its future in
# `futures`, keyed by request_id, as a (future, method-name) pair.
183 logger.info(f"Sending {method} request to backend: {request_id}")
184 if method == "UnaryCall":
185 future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
# NOTE(review): elided lines presumably pass metadata=metadata and
# timeout=timeout to the .future() calls, and an `else:` precedes the raise
# below — confirm against the full source.
188 elif method == "EmptyCall":
189 future = stub.EmptyCall.future(empty_pb2.Empty(),
193 raise ValueError(f"Unrecognized method '{method}'.")
194 futures[request_id] = (future, method)
197 def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str,
198 print_response: bool) -> None:
# Completion handler for one RPC: updates the global per-method counters and
# per-status tallies, then notifies every registered _StatsWatcher.
199 exception = future.exception()
# future.code().value is a (numeric, name) pair; index 0 is the numeric code.
201 _global_rpc_statuses[method][future.code().value[0]] += 1
202 if exception is not None:
# NOTE(review): elided lines likely acquire _global_lock around the counter
# updates in this function — confirm.
204 _global_rpcs_failed[method] += 1
205 if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
206 logger.error(f"RPC {rpc_id} timed out")
208 logger.error(exception)
# Success path (the `else:` is elided): pull the peer hostname, preferring
# the "hostname" entry in initial metadata over the response body's field.
210 response = future.result()
212 for metadatum in future.initial_metadata():
213 if metadatum[0] == "hostname":
214 hostname = metadatum[1]
217 hostname = response.hostname
218 if future.code() == grpc.StatusCode.OK:
220 _global_rpcs_succeeded[method] += 1
223 _global_rpcs_failed[method] += 1
225 if future.code() == grpc.StatusCode.OK:
226 logger.info("Successful response.")
# NOTE(review): `call` is not defined anywhere in the visible code — this
# f-string would raise NameError if reached; it likely should reference
# rpc_id or the future. Confirm against the full source.
228 logger.info(f"RPC failed: {call}")
# Report completion to every active watcher (lock acquisition elided).
230 for watcher in _watchers:
231 watcher.on_rpc_complete(rpc_id, hostname, method)
234 def _remove_completed_rpcs(futures: Mapping[int, grpc.Future],
235 print_response: bool) -> None:
# Drains finished futures from `futures`, dispatching each to _on_rpc_done.
# NOTE(review): the value annotation above says grpc.Future, but the values
# unpacked below are (future, method) tuples — the annotation should be
# Mapping[int, Tuple[grpc.Future, str]].
236 logger.debug("Removing completed RPCs")
238 for future_id, (future, method) in futures.items():
# NOTE(review): this passes the global `args.print_response` instead of the
# `print_response` parameter — they coincide at the current call site, but
# the parameter should be used.
240 _on_rpc_done(future_id, future, method, args.print_response)
241 done.append(future_id)
# `done` collects completed ids so they can be deleted after iteration; its
# initialization and the deletion loop are elided from this listing.
# Best-effort cancellation of every in-flight RPC, used at shutdown.
246 def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
247 logger.info("Cancelling all remaining RPCs")
248 for future, _ in futures.values():
# NOTE(review): the loop body (presumably `future.cancel()`) is elided here.
252 class _ChannelConfiguration:
253 """Configuration for a single client channel.
255 Instances of this class are meant to be dealt with as PODs. That is,
256 data members should be accessed directly. This class is not thread-safe.
257 When accessing any of its members, the lock member should be held.
260 def __init__(self, method: str, metadata: Sequence[Tuple[str,
262 server: str, rpc_timeout_sec: int, print_response: bool):
263 # condition is signalled when a change is made to the config.
264 self.condition = threading.Condition()
267 self.metadata = metadata
270 self.rpc_timeout_sec = rpc_timeout_sec
271 self.print_response = print_response
# NOTE(review): assignments for method, qps, and server are elided from this
# listing; _run_single_channel reads config.method/qps/server — confirm.
274 def _run_single_channel(config: _ChannelConfiguration) -> None:
# Worker-thread body: opens one channel to config.server and issues RPCs at
# config.qps until _stop_event is set, then cancels whatever is in flight.
275 global _global_rpc_id # pylint: disable=global-statement
276 with config.condition:
277 server = config.server
278 with grpc.insecure_channel(server) as channel:
279 stub = test_pb2_grpc.TestServiceStub(channel)
# request_id -> (future, method) for RPCs still outstanding on this channel.
280 futures: Dict[int, Tuple[grpc.Future, str]] = {}
281 while not _stop_event.is_set():
282 with config.condition:
# NOTE(review): elided lines likely handle qps == 0 by waiting for a config
# change notification before re-checking — confirm.
284 config.condition.wait(
285 timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds())
# Time budget for one query at the configured rate.
288 duration_per_query = 1.0 / float(config.qps)
# Allocate the next global RPC id and count the attempt; the _global_lock
# acquisition and the `_global_rpc_id += 1` line are elided in this listing.
291 request_id = _global_rpc_id
293 _global_rpcs_started[config.method] += 1
295 end = start + duration_per_query
296 with config.condition:
297 _start_rpc(config.method, config.metadata, request_id, stub,
298 float(config.rpc_timeout_sec), futures)
299 with config.condition:
300 _remove_completed_rpcs(futures, config.print_response)
301 logger.debug(f"Currently {len(futures)} in-flight RPCs")
# Sleep off the remainder of this query's time slice to hold the QPS rate
# (the `now = time.time()` / guard lines are elided).
304 time.sleep(end - now)
306 _cancel_all_rpcs(futures)
# Lets the test driver reconfigure at runtime which methods run, and with
# what metadata and timeout.
309 class _XdsUpdateClientConfigureServicer(
310 test_pb2_grpc.XdsUpdateClientConfigureServiceServicer):
312 def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration],
# NOTE(review): super(_XdsUpdateClientConfigureServicer).__init__() is
# missing its second argument; it should be super().__init__().
314 super(_XdsUpdateClientConfigureServicer).__init__()
315 self._per_method_configs = per_method_configs
# Configure (the def line is elided): for each supported method, enable it
# with the request's metadata/timeout if listed in request.types, otherwise
# disable it.
319 self, request: messages_pb2.ClientConfigureRequest,
320 context: grpc.ServicerContext
321 ) -> messages_pb2.ClientConfigureResponse:
322 logger.info("Received Configure RPC: %s", request)
# NOTE(review): method_strs is a generator; the `method in method_strs`
# membership test below consumes it, so later iterations of the outer loop
# see a partially/fully exhausted generator. Materializing it as a list or
# set would be safer — confirm intended behavior against the full source.
323 method_strs = (_METHOD_ENUM_TO_STR[t] for t in request.types)
324 for method in _SUPPORTED_METHODS:
325 method_enum = _METHOD_STR_TO_ENUM[method]
326 channel_config = self._per_method_configs[method]
327 if method in method_strs:
# Keep only the metadata entries targeted at this method.
329 metadata = ((md.key, md.value)
330 for md in request.metadata
331 if md.type == method_enum)
332 # For backward compatibility, do not change timeout when we
333 # receive a default value timeout.
334 if request.timeout_sec == 0:
335 timeout_sec = channel_config.rpc_timeout_sec
337 timeout_sec = request.timeout_sec
# Disabled branch (the `else:` and the qps/metadata resets are elided).
341 # Leave timeout unchanged for backward compatibility.
342 timeout_sec = channel_config.rpc_timeout_sec
# Publish the new settings under the config lock and wake worker threads.
343 with channel_config.condition:
344 channel_config.qps = qps
345 channel_config.metadata = list(metadata)
346 channel_config.rpc_timeout_sec = timeout_sec
347 channel_config.condition.notify_all()
348 return messages_pb2.ClientConfigureResponse()
# _MethodHandle class body (the class header, original line 351, is elided
# from this listing).
352 """An object grouping together threads driving RPCs for a method."""
# Threads created in __init__ and joined in stop().
354 _channel_threads: List[threading.Thread]
356 def __init__(self, num_channels: int,
357 channel_config: _ChannelConfiguration):
358 """Creates and starts a group of threads running the indicated method."""
359 self._channel_threads = []
# One worker thread per channel, all sharing the same mutable config.
360 for i in range(num_channels):
361 thread = threading.Thread(target=_run_single_channel,
362 args=(channel_config,))
# NOTE(review): a `thread.start()` line appears to be elided here, per the
# docstring's claim that the threads are started — confirm.
364 self._channel_threads.append(thread)
def stop(self) -> None:
    """Blocks until every channel thread owned by this handle has exited."""
    # Joining (rather than signalling) is sufficient here: the workers watch
    # the module-level _stop_event and exit on their own.
    for worker in self._channel_threads:
        worker.join()
372 def _run(args: argparse.Namespace, methods: Sequence[str],
373 per_method_metadata: PerMethodMetadataType) -> None:
# Builds per-method channel configurations, spins up the worker threads, and
# serves the stats/configure services until the server terminates.
374 logger.info("Starting python xDS Interop Client.")
375 global _global_server # pylint: disable=global-statement
# (Initialization of channel_configs/method_handles and the per-method qps
# selection — presumably args.qps for enabled methods, 0 otherwise — are
# elided from this listing.)
378 for method in _SUPPORTED_METHODS:
379 if method in methods:
383 channel_config = _ChannelConfiguration(
384 method, per_method_metadata.get(method, []), qps, args.server,
385 args.rpc_timeout_sec, args.print_response)
386 channel_configs[method] = channel_config
387 method_handles.append(_MethodHandle(args.num_channels, channel_config))
# Expose the stats and configure services on args.stats_port (plaintext).
388 _global_server = grpc.server(futures.ThreadPoolExecutor())
389 _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
390 test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
391 _LoadBalancerStatsServicer(), _global_server)
392 test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
393 _XdsUpdateClientConfigureServicer(channel_configs, args.qps),
# Blocks until the server is stopped (e.g. by the SIGINT handler), then
# joins the worker threads.
395 _global_server.start()
396 _global_server.wait_for_termination()
397 for method_handle in method_handles:
# NOTE(review): the loop body (presumably `method_handle.stop()`) is elided.
def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType:
    """Parses the --metadata flag value into per-method metadata.

    Args:
      metadata_arg: A comma-delimited list of METHOD:KEY:VALUE triples; may
        be empty.

    Returns:
      A mapping from method name to a list of (key, value) metadata tuples.

    Raises:
      ValueError: If a triple is malformed or names an unsupported method.
    """
    # Bug fix: this previously tested the module-level `args.metadata` instead
    # of the `metadata_arg` parameter. The two coincide at the current call
    # site, but using the parameter makes the function self-contained.
    metadata = metadata_arg.split(",") if metadata_arg else []
    per_method_metadata = collections.defaultdict(list)
    for metadatum in metadata:
        elems = metadatum.split(":")
        # Each entry must be exactly METHOD:KEY:VALUE.
        if len(elems) != 3:
            raise ValueError(
                f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
        if elems[0] not in _SUPPORTED_METHODS:
            raise ValueError(f"Unrecognized method '{elems[0]}'")
        per_method_metadata[elems[0]].append((elems[1], elems[2]))
    return per_method_metadata
def parse_rpc_arg(rpc_arg: str) -> Sequence[str]:
    """Parses the --rpc flag value into a list of method names.

    Args:
      rpc_arg: A comma-delimited list of RPC method names.

    Returns:
      The parsed method names, in the order given.

    Raises:
      ValueError: If any name is not in _SUPPORTED_METHODS.
    """
    methods = rpc_arg.split(",")
    if set(methods) - set(_SUPPORTED_METHODS):
        raise ValueError("--rpc supported methods: {}".format(
            ", ".join(_SUPPORTED_METHODS)))
    # The declared Sequence[str] return and the call site consuming this
    # function's result require handing the parsed list back; the return
    # statement was absent from the reviewed listing.
    return methods
423 if __name__ == "__main__":
# CLI entry point: parse flags, install the SIGINT handler, configure
# logging, and run the client. (Several add_argument calls below are
# partially elided in this listing — only their help/default lines remain.)
424 parser = argparse.ArgumentParser(
425 description='Run Python XDS interop client.')
430 help="The number of channels from which to send requests.")
431 parser.add_argument("--print_response",
434 help="Write RPC response to STDOUT.")
439 help="The number of queries to send from each channel per second.")
440 parser.add_argument("--rpc_timeout_sec",
443 help="The per-RPC timeout in seconds.")
444 parser.add_argument("--server",
445 default="localhost:50051",
446 help="The address of the server.")
451 help="The port on which to expose the peer distribution stats service.")
452 parser.add_argument('--verbose',
453 help='verbose log output',
456 parser.add_argument("--log_file",
459 help="A file to log to.")
460 rpc_help = "A comma-delimited list of RPC methods to run. Must be one of "
461 rpc_help += ", ".join(_SUPPORTED_METHODS)
463 parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help)
# (The `metadata_help = (` assignment line is elided before these pieces.)
465 "A comma-delimited list of 3-tuples of the form " +
466 "METHOD:KEY:VALUE, e.g. " +
467 "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3")
468 parser.add_argument("--metadata", default="", type=str, help=metadata_help)
469 args = parser.parse_args()
470 signal.signal(signal.SIGINT, _handle_sigint)
# The `if args.verbose:` guard is elided; when verbose, log at DEBUG.
472 logger.setLevel(logging.DEBUG)
# Optional file logging (the `if args.log_file:` guard is elided), reusing
# the console formatter defined at module scope.
474 file_handler = logging.FileHandler(args.log_file, mode='a')
475 file_handler.setFormatter(formatter)
476 logger.addHandler(file_handler)
477 _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata))