9a30dca7bdf8be5faaa0294e65714c9f6a5c627d
[platform/upstream/grpc.git] / tools / run_tests / xds_k8s_test_driver / framework / test_app / client_app.py
1 # Copyright 2020 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """
15 xDS Test Client.
16
17 TODO(sergiitk): separate XdsTestClient and KubernetesClientRunner to individual
18 modules.
19 """
20 import datetime
21 import functools
22 import logging
23 from typing import Iterable, List, Optional
24
25 from framework.helpers import retryers
26 from framework.infrastructure import gcp
27 from framework.infrastructure import k8s
28 import framework.rpc
29 from framework.rpc import grpc_channelz
30 from framework.rpc import grpc_csds
31 from framework.rpc import grpc_testing
32 from framework.test_app import base_runner
33
logger = logging.getLogger(__name__)

# Type aliases.
# Standard library.
_timedelta = datetime.timedelta
# Test client services (grpc.testing).
_LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
_XdsUpdateClientConfigureServiceClient = (
    grpc_testing.XdsUpdateClientConfigureServiceClient)
# Channelz.
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelzChannel = grpc_channelz.Channel
_ChannelzChannelState = grpc_channelz.ChannelState
_ChannelzSubchannel = grpc_channelz.Subchannel
_ChannelzSocket = grpc_channelz.Socket
# Client Status Discovery Service (CSDS).
_CsdsClient = grpc_csds.CsdsClient
46
47
class XdsTestClient(framework.rpc.grpc.GrpcApp):
    """
    Represents RPC services implemented in Client component of the xds test app.
    https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client
    """

    def __init__(self,
                 *,
                 ip: str,
                 rpc_port: int,
                 server_target: str,
                 rpc_host: Optional[str] = None,
                 maintenance_port: Optional[int] = None):
        """Creates a handle to a running xDS test client.

        Args:
            ip: IP address of the client. Also used as the host to connect to
                when rpc_host is not given.
            rpc_port: Port serving the LoadBalancerStatsService and
                XdsUpdateClientConfigureService.
            server_target: The client's server target; used to look up the
                client's channels to the server in channelz.
            rpc_host: Optional host to connect to instead of ip (for example,
                a local port-forwarding address).
            maintenance_port: Port serving channelz and CSDS. Falls back to
                rpc_port when not set.
        """
        super().__init__(rpc_host=(rpc_host or ip))
        self.ip = ip
        self.rpc_port = rpc_port
        self.server_target = server_target
        self.maintenance_port = maintenance_port or rpc_port

    # The service-client properties below are memoized with lru_cache, so each
    # instance creates its channel and stub only once. Note: lru_cache on a
    # method keeps a strong reference to every instance it was called on.

    @property
    @functools.lru_cache(None)
    def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient:
        """LoadBalancerStatsService client, connected to the RPC port."""
        return _LoadBalancerStatsServiceClient(self._make_channel(
            self.rpc_port))

    @property
    @functools.lru_cache(None)
    def update_config(self) -> _XdsUpdateClientConfigureServiceClient:
        """XdsUpdateClientConfigureService client, connected to the RPC port."""
        return _XdsUpdateClientConfigureServiceClient(
            self._make_channel(self.rpc_port))

    @property
    @functools.lru_cache(None)
    def channelz(self) -> _ChannelzServiceClient:
        """Channelz client, connected to the maintenance port."""
        return _ChannelzServiceClient(self._make_channel(self.maintenance_port))

    @property
    @functools.lru_cache(None)
    def csds(self) -> _CsdsClient:
        """CSDS client, connected to the maintenance port."""
        return _CsdsClient(self._make_channel(self.maintenance_port))

    def get_load_balancer_stats(
        self,
        *,
        num_rpcs: int,
        timeout_sec: Optional[int] = None,
    ) -> grpc_testing.LoadBalancerStatsResponse:
        """
        Shortcut to LoadBalancerStatsServiceClient.get_client_stats()
        """
        return self.load_balancer_stats.get_client_stats(
            num_rpcs=num_rpcs, timeout_sec=timeout_sec)

    def get_load_balancer_accumulated_stats(
        self,
        *,
        timeout_sec: Optional[int] = None,
    ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse:
        """Shortcut to LoadBalancerStatsServiceClient.get_client_accumulated_stats()"""
        return self.load_balancer_stats.get_client_accumulated_stats(
            timeout_sec=timeout_sec)

    def wait_for_active_server_channel(self) -> _ChannelzChannel:
        """Wait for the channel to the server to transition to READY.

        Raises:
            GrpcApp.NotFound: If the channel never transitioned to READY.
        """
        return self.wait_for_server_channel_state(_ChannelzChannelState.READY)

    def get_active_server_channel_socket(self) -> _ChannelzSocket:
        """Returns the socket used by the client's READY channel to the server.

        Uses the first subchannel of the READY channel and the first socket of
        that subchannel. Any additional subchannels or sockets are logged as
        unexpected.

        Raises:
            GrpcApp.NotFound: If the client has no READY channel to the server.
            ValueError: If the channel has no subchannels, or the subchannel
                has no sockets (raised by the tuple unpacking below).
        """
        channel = self.find_server_channel_with_state(
            _ChannelzChannelState.READY)
        # Get the first subchannel of the active channel to the server.
        logger.debug(
            'Retrieving client -> server socket, '
            'channel_id: %s, subchannel: %s', channel.ref.channel_id,
            channel.subchannel_ref[0].name)
        subchannel, *subchannels = self.channelz.list_channel_subchannels(
            channel)
        if subchannels:
            logger.warning('Unexpected subchannels: %r', subchannels)
        # Get the first socket of the subchannel
        socket, *sockets = self.channelz.list_subchannels_sockets(subchannel)
        if sockets:
            # Log the extra sockets (previously logged subchannels by mistake).
            logger.warning('Unexpected sockets: %r', sockets)
        logger.debug('Found client -> server socket: %s', socket.ref.name)
        return socket

    def wait_for_server_channel_state(
            self,
            state: _ChannelzChannelState,
            *,
            timeout: Optional[_timedelta] = None,
            rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel:
        """Retries until the client has a channel to the server in `state`.

        Args:
            state: The channelz channel state to wait for.
            timeout: Total time to keep retrying. Defaults to 5 minutes.
            rpc_deadline: Deadline of each individual channelz RPC.
                Defaults to 30 seconds.

        Returns:
            The matching channel, as found by find_server_channel_with_state().
        """
        # When polling for a state, prefer smaller wait times to avoid
        # exhausting all allowed time on a single long RPC.
        if rpc_deadline is None:
            rpc_deadline = _timedelta(seconds=30)

        # Fine-tuned to wait for the channel to the server.
        retryer = retryers.exponential_retryer_with_timeout(
            wait_min=_timedelta(seconds=10),
            wait_max=_timedelta(seconds=25),
            timeout=_timedelta(minutes=5) if timeout is None else timeout)

        logger.info('Waiting for client %s to report a %s channel to %s',
                    self.ip, _ChannelzChannelState.Name(state),
                    self.server_target)
        channel = retryer(self.find_server_channel_with_state,
                          state,
                          rpc_deadline=rpc_deadline)
        logger.info('Client %s channel to %s transitioned to state %s:\n%s',
                    self.ip, self.server_target,
                    _ChannelzChannelState.Name(state), channel)
        return channel

    def find_server_channel_with_state(
            self,
            state: _ChannelzChannelState,
            *,
            rpc_deadline: Optional[_timedelta] = None,
            check_subchannel=True) -> _ChannelzChannel:
        """Returns the first channel to the server that is in `state`.

        Args:
            state: The channelz channel state to look for.
            rpc_deadline: Optional deadline for each channelz RPC.
            check_subchannel: When True, a channel only matches if it also has
                at least one subchannel in `state`.

        Raises:
            GrpcApp.NotFound: If no matching channel is found.
        """
        rpc_params = {}
        if rpc_deadline is not None:
            rpc_params['deadline_sec'] = rpc_deadline.total_seconds()

        for channel in self.get_server_channels(**rpc_params):
            channel_state: _ChannelzChannelState = channel.data.state.state
            logger.info('Server channel: %s, state: %s', channel.ref.name,
                        _ChannelzChannelState.Name(channel_state))
            # Protobuf enum values are plain ints: compare by value, not
            # identity.
            if channel_state == state:
                if check_subchannel:
                    # When requested, check if the channel has at least
                    # one subchannel in the requested state.
                    try:
                        subchannel = self.find_subchannel_with_state(
                            channel, state, **rpc_params)
                        logger.info('Found subchannel in state %s: %s',
                                    _ChannelzChannelState.Name(state),
                                    subchannel)
                    except self.NotFound as e:
                        # Otherwise, keep searching.
                        logger.info(e.message)
                        continue
                return channel

        raise self.NotFound(
            f'Client has no {_ChannelzChannelState.Name(state)} channel with '
            'the server')

    def get_server_channels(self, **kwargs) -> Iterable[_ChannelzChannel]:
        """Returns the client's channelz channels for its server target.

        kwargs are forwarded to ChannelzServiceClient.find_channels_for_target.
        """
        return self.channelz.find_channels_for_target(self.server_target,
                                                      **kwargs)

    def find_subchannel_with_state(self, channel: _ChannelzChannel,
                                   state: _ChannelzChannelState,
                                   **kwargs) -> _ChannelzSubchannel:
        """Returns the first subchannel of `channel` that is in `state`.

        Raises:
            GrpcApp.NotFound: If no subchannel of `channel` is in `state`.
        """
        subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
        for subchannel in subchannels:
            if subchannel.data.state.state == state:
                return subchannel

        raise self.NotFound(
            f'Not found a {_ChannelzChannelState.Name(state)} '
            f'subchannel for channel_id {channel.ref.channel_id}')

    def find_subchannels_with_state(self, state: _ChannelzChannelState,
                                    **kwargs) -> List[_ChannelzSubchannel]:
        """Returns all subchannels in `state` across all server channels.

        kwargs are forwarded to every channelz RPC made.
        """
        subchannels = []
        for channel in self.get_server_channels(**kwargs):
            for subchannel in self.channelz.list_channel_subchannels(
                    channel, **kwargs):
                if subchannel.data.state.state == state:
                    subchannels.append(subchannel)
        return subchannels
226
227
class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
    """Runs the xDS test client as a Kubernetes deployment.

    run() creates the Kubernetes service account and the client deployment,
    and grants that service account permission to use the configured GCP
    service account identity. cleanup() reverses those steps.
    """

    def __init__(self,
                 k8s_namespace,
                 *,
                 deployment_name,
                 image_name,
                 td_bootstrap_image,
                 gcp_api_manager: gcp.api.GcpApiManager,
                 gcp_project: str,
                 gcp_service_account: str,
                 xds_server_uri=None,
                 network='default',
                 service_account_name=None,
                 stats_port=8079,
                 deployment_template='client.deployment.yaml',
                 service_account_template='service-account.yaml',
                 reuse_namespace=False,
                 namespace_template=None,
                 debug_use_port_forwarding=False):
        super().__init__(k8s_namespace, namespace_template, reuse_namespace)

        # Deployment settings.
        self.deployment_name = deployment_name
        self.deployment_template = deployment_template
        self.image_name = image_name
        self.stats_port = stats_port
        self.network = network
        self.debug_use_port_forwarding = debug_use_port_forwarding

        # Inputs for the xDS bootstrap generator.
        self.td_bootstrap_image = td_bootstrap_image
        self.xds_server_uri = xds_server_uri

        # Kubernetes service account; named after the deployment unless an
        # explicit name is given.
        self.service_account_template = service_account_template
        self.service_account_name = service_account_name or deployment_name

        # GCP project settings, and the GCP service account that the
        # Kubernetes service account is mapped to.
        self.gcp_project = gcp_project
        self.gcp_ui_url = gcp_api_manager.gcp_ui_url
        self.gcp_service_account = gcp_service_account
        # IAM API client, used to grant/revoke the workload identity binding.
        self.gcp_iam = gcp.iam.IamV1(gcp_api_manager, gcp_project)

        # Mutable state: set by run(), reset by cleanup().
        self.deployment: Optional[k8s.V1Deployment] = None
        self.service_account: Optional[k8s.V1ServiceAccount] = None
        self.port_forwarder = None

    # TODO(sergiitk): make rpc UnaryCall enum or get it from proto
    def run(self,
            *,
            server_target,
            rpc='UnaryCall',
            qps=25,
            metadata='',
            secure_mode=False,
            print_response=False) -> XdsTestClient:
        """Deploys the xDS test client and returns a handle to it.

        Steps: grant workload identity to the Kubernetes service account,
        create that service account, create a new deployment, wait for its
        replicas to be available, then wait for the first pod to start.

        Args:
            server_target: Server target configured for the client.
            rpc, qps, metadata, secure_mode, print_response: Client settings
                rendered into the deployment template.

        Returns:
            An XdsTestClient for the first pod of the deployment. When
            debug_use_port_forwarding is set, it connects through a local
            port-forwarding address instead of the pod IP.
        """
        logger.info(
            'Deploying xDS test client "%s" to k8s namespace %s: '
            'server_target=%s rpc=%s qps=%s metadata=%r secure_mode=%s '
            'print_response=%s', self.deployment_name, self.k8s_namespace.name,
            server_target, rpc, qps, metadata, secure_mode, print_response)
        self._logs_explorer_link(deployment_name=self.deployment_name,
                                 namespace_name=self.k8s_namespace.name,
                                 gcp_project=self.gcp_project,
                                 gcp_ui_url=self.gcp_ui_url)

        super().run()
        namespace_name = self.k8s_namespace.name

        # Let the Kubernetes service account act as the GCP service account.
        self._grant_workload_identity_user(
            gcp_iam=self.gcp_iam,
            gcp_service_account=self.gcp_service_account,
            service_account_name=self.service_account_name)

        self.service_account = self._create_service_account(
            self.service_account_template,
            service_account_name=self.service_account_name,
            namespace_name=namespace_name,
            gcp_service_account=self.gcp_service_account)

        # A brand new deployment is created on every run.
        deployment_params = dict(
            deployment_name=self.deployment_name,
            image_name=self.image_name,
            namespace_name=namespace_name,
            service_account_name=self.service_account_name,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            stats_port=self.stats_port,
            server_target=server_target,
            rpc=rpc,
            qps=qps,
            metadata=metadata,
            secure_mode=secure_mode,
            print_response=print_response,
        )
        self.deployment = self._create_deployment(self.deployment_template,
                                                  **deployment_params)
        self._wait_deployment_with_available_replicas(self.deployment_name)

        # Only a single client pod is supported for now: take the first one.
        deployment_pods = self.k8s_namespace.list_deployment_pods(
            self.deployment)
        pod = deployment_pods[0]
        self._wait_pod_started(pod.metadata.name)
        pod_ip = pod.status.pod_ip

        rpc_host = None
        if self.debug_use_port_forwarding:
            # Experimental, for local debugging: reach the pod through a
            # port forward instead of its IP.
            logger.info('LOCAL DEV MODE: Enabling port forwarding to %s:%s',
                        pod_ip, self.stats_port)
            self.port_forwarder = self.k8s_namespace.port_forward_pod(
                pod, remote_port=self.stats_port)
            rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS

        return XdsTestClient(ip=pod_ip,
                             rpc_port=self.stats_port,
                             server_target=server_target,
                             rpc_host=rpc_host)

    def cleanup(self, *, force=False, force_namespace=False):
        """Tears down what run() created.

        Port forwarding is always stopped when active. The deployment and the
        service account (with its workload identity binding) are removed when
        run() created them, or unconditionally when `force` is set. Namespace
        cleanup is forced only when both `force` and `force_namespace` are set.
        """
        forwarder = self.port_forwarder
        if forwarder:
            self.k8s_namespace.port_forward_stop(forwarder)
            self.port_forwarder = None

        if self.deployment or force:
            self._delete_deployment(self.deployment_name)
            self.deployment = None

        if self.service_account or force:
            # Revoke the IAM binding before removing the service account.
            self._revoke_workload_identity_user(
                gcp_iam=self.gcp_iam,
                gcp_service_account=self.gcp_service_account,
                service_account_name=self.service_account_name)
            self._delete_service_account(self.service_account_name)
            self.service_account = None

        super().cleanup(force=force_namespace and force)

    @classmethod
    def make_namespace_name(cls,
                            resource_prefix: str,
                            resource_suffix: str,
                            name: str = 'client') -> str:
        """Builds the Kubernetes namespace name for the XdsTestClient.

        The name is derived from the given resource prefix and suffix. The
        test client and the test server intentionally get different
        namespaces, because that more closely mimics real-world deployments.
        """
        return cls._make_namespace_name(resource_prefix, resource_suffix, name)
379         """
380         return cls._make_namespace_name(resource_prefix, resource_suffix, name)