Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / tools / run_tests / xds_k8s_test_driver / framework / rpc / grpc_channelz.py
1 # Copyright 2020 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """
15 This contains helpers for gRPC services defined in
16 https://github.com/grpc/grpc-proto/blob/master/grpc/channelz/v1/channelz.proto
17 """
18 import ipaddress
19 import logging
20 from typing import Iterator, Optional
21
22 import grpc
23 from grpc_channelz.v1 import channelz_pb2
24 from grpc_channelz.v1 import channelz_pb2_grpc
25
26 import framework.rpc
27
logger = logging.getLogger(__name__)

# Aliases for channelz.proto message types.
#
# Public aliases: message types handed to / returned from the helpers below.
Channel = channelz_pb2.Channel
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
ChannelState = ChannelConnectivityState.State  # pylint: disable=no-member
Subchannel = channelz_pb2.Subchannel
Server = channelz_pb2.Server
Socket = channelz_pb2.Socket
SocketRef = channelz_pb2.SocketRef
Address = channelz_pb2.Address
Security = channelz_pb2.Security

# Private aliases: request/response messages used internally by
# ChannelzServiceClient, listed as request/response pairs per RPC.
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
_GetSubchannelRequest = channelz_pb2.GetSubchannelRequest
_GetSubchannelResponse = channelz_pb2.GetSubchannelResponse
_GetServersRequest = channelz_pb2.GetServersRequest
_GetServersResponse = channelz_pb2.GetServersResponse
_GetSocketRequest = channelz_pb2.GetSocketRequest
_GetSocketResponse = channelz_pb2.GetSocketResponse
_GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
_GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
53 _GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
54 _GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
55
56
class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
    """Client helper for the gRPC Channelz v1 service.

    Wraps channelz_pb2_grpc.ChannelzStub. Every RPC is issued through
    call_unary_with_deadline (inherited from GrpcClientHelper); extra keyword
    arguments accepted by the public methods are forwarded to it unchanged.
    """
    stub: channelz_pb2_grpc.ChannelzStub

    def __init__(self, channel: grpc.Channel):
        super().__init__(channel, channelz_pb2_grpc.ChannelzStub)

    @staticmethod
    def is_sock_tcpip_address(address: Address) -> bool:
        """Return True if the `address` oneof is set to `tcpip_address`."""
        return address.WhichOneof('address') == 'tcpip_address'

    @staticmethod
    def is_ipv4(tcpip_address: Address.TcpIpAddress) -> bool:
        """Return True if the packed IP address is 4 bytes long (IPv4)."""
        # According to proto, tcpip_address.ip_address is either IPv4 or IPv6.
        # Correspondingly, it's either 4 bytes or 16 bytes in length.
        return len(tcpip_address.ip_address) == 4

    @classmethod
    def sock_address_to_str(cls, address: Address) -> str:
        """Format a TCP/IP socket address as '<ip>:<port>'.

        Raises:
            NotImplementedError: if `address` is not a tcpip_address.
        """
        if not cls.is_sock_tcpip_address(address):
            raise NotImplementedError('Only tcpip_address implemented')
        tcpip_address: Address.TcpIpAddress = address.tcpip_address
        if cls.is_ipv4(tcpip_address):
            ip = ipaddress.IPv4Address(tcpip_address.ip_address)
        else:
            ip = ipaddress.IPv6Address(tcpip_address.ip_address)
        # NOTE(review): IPv6 addresses are not wrapped in brackets, so the
        # result can be ambiguous (e.g. '::1:8080'). Kept unchanged in case
        # callers depend on this exact format.
        return f'{ip}:{tcpip_address.port}'

    @classmethod
    def sock_addresses_pretty(cls, socket: Socket) -> str:
        """Return 'local=<ip:port>, remote=<ip:port>' for `socket`."""
        return (f'local={cls.sock_address_to_str(socket.local)}, '
                f'remote={cls.sock_address_to_str(socket.remote)}')

    @staticmethod
    def find_server_socket_matching_client(
            server_sockets: Iterator[Socket],
            client_socket: Socket) -> Optional[Socket]:
        """Find the server-side peer of a client socket.

        Returns the first socket in `server_sockets` whose remote address
        equals `client_socket`'s local address, or None if none matches.
        Iteration over `server_sockets` stops at the first match.
        """
        return next((server_socket for server_socket in server_sockets
                     if server_socket.remote == client_socket.local), None)

    def find_channels_for_target(self, target: str,
                                 **kwargs) -> Iterator[Channel]:
        """Lazily yield root channels whose `data.target` equals `target`."""
        return (channel for channel in self.list_channels(**kwargs)
                if channel.data.target == target)

    def find_server_listening_on_port(self, port: int,
                                      **kwargs) -> Optional[Server]:
        """Return the first server with a TCP/IP listen socket on `port`.

        Each listen socket is fetched with GetSocket to read its local
        address. Returns None if no server listens on `port`.
        """
        for server in self.list_servers(**kwargs):
            listen_socket_ref: SocketRef
            for listen_socket_ref in server.listen_socket:
                listen_socket = self.get_socket(listen_socket_ref.socket_id,
                                                **kwargs)
                listen_address: Address = listen_socket.local
                if (self.is_sock_tcpip_address(listen_address) and
                        listen_address.tcpip_address.port == port):
                    return server
        return None

    def _iter_pages(self, rpc: str, make_request, items_of, id_of,
                    **kwargs) -> Iterator:
        """Yield every result item of a paginated Channelz listing RPC.

        Args:
            rpc: Name of the Channelz RPC to call.
            make_request: Callable building the request message for the page
                that starts at the given ID.
            items_of: Callable returning the repeated result field of a
                response message.
            id_of: Callable returning the numeric ID of one result item.
            **kwargs: Forwarded to call_unary_with_deadline.
        """
        start: int = -1
        response = None
        # `start < 0` forces the first request; afterwards stop when the
        # server reports the last page via `response.end`.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(rpc=rpc,
                                                     req=make_request(start),
                                                     **kwargs)
            for item in items_of(response):
                start = max(start, id_of(item))
                yield item

    def list_channels(self, **kwargs) -> Iterator[Channel]:
        """
        Iterate over all pages of all root channels.

        Root channels are those which application has directly created.
        This does not include subchannels nor non-top level channels.
        """
        yield from self._iter_pages(
            'GetTopChannels',
            lambda start: _GetTopChannelsRequest(start_channel_id=start),
            lambda response: response.channel,
            lambda channel: channel.ref.channel_id,
            **kwargs)

    def list_servers(self, **kwargs) -> Iterator[Server]:
        """Iterate over all pages of all servers that exist in the process."""
        yield from self._iter_pages(
            'GetServers',
            lambda start: _GetServersRequest(start_server_id=start),
            lambda response: response.server,
            lambda server: server.ref.server_id,
            **kwargs)

    def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
        """List all server sockets that exist in server process.

        Iterating over the results will resolve additional pages automatically.
        Each listed SocketRef is resolved to a full Socket via GetSocket.
        """
        socket_ref: SocketRef
        for socket_ref in self._iter_pages(
                'GetServerSockets',
                lambda start: _GetServerSocketsRequest(
                    server_id=server.ref.server_id, start_socket_id=start),
                lambda response: response.socket_ref,
                lambda ref: ref.socket_id,
                **kwargs):
            # Yield actual socket
            yield self.get_socket(socket_ref.socket_id, **kwargs)

    def list_channel_sockets(self, channel: Channel,
                             **kwargs) -> Iterator[Socket]:
        """List all sockets of all subchannels of a given channel."""
        for subchannel in self.list_channel_subchannels(channel, **kwargs):
            yield from self.list_subchannels_sockets(subchannel, **kwargs)

    def list_channel_subchannels(self, channel: Channel,
                                 **kwargs) -> Iterator[Subchannel]:
        """List all subchannels of a given channel (one GetSubchannel each)."""
        for subchannel_ref in channel.subchannel_ref:
            yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)

    def list_subchannels_sockets(self, subchannel: Subchannel,
                                 **kwargs) -> Iterator[Socket]:
        """List all sockets of a given subchannel (one GetSocket each)."""
        for socket_ref in subchannel.socket_ref:
            yield self.get_socket(socket_ref.socket_id, **kwargs)

    def get_subchannel(self, subchannel_id: int, **kwargs) -> Subchannel:
        """Return a single Subchannel, otherwise raises RpcError."""
        response: _GetSubchannelResponse = self.call_unary_with_deadline(
            rpc='GetSubchannel',
            req=_GetSubchannelRequest(subchannel_id=subchannel_id),
            **kwargs)
        return response.subchannel

    def get_socket(self, socket_id: int, **kwargs) -> Socket:
        """Return a single Socket, otherwise raises RpcError."""
        response: _GetSocketResponse = self.call_unary_with_deadline(
            rpc='GetSocket',
            req=_GetSocketRequest(socket_id=socket_id),
            **kwargs)
        return response.socket