Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / tools / run_tests / xds_k8s_test_driver / framework / rpc / grpc_channelz.py
index 3bf4b26..b4e6b18 100644 (file)
@@ -95,22 +95,25 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
                 return server_socket
         return None
 
-    def find_channels_for_target(self, target: str) -> Iterator[Channel]:
-        return (channel for channel in self.list_channels()
+    def find_channels_for_target(self, target: str,
+                                 **kwargs) -> Iterator[Channel]:
+        return (channel for channel in self.list_channels(**kwargs)
                 if channel.data.target == target)
 
-    def find_server_listening_on_port(self, port: int) -> Optional[Server]:
-        for server in self.list_servers():
+    def find_server_listening_on_port(self, port: int,
+                                      **kwargs) -> Optional[Server]:
+        for server in self.list_servers(**kwargs):
             listen_socket_ref: SocketRef
             for listen_socket_ref in server.listen_socket:
-                listen_socket = self.get_socket(listen_socket_ref.socket_id)
+                listen_socket = self.get_socket(listen_socket_ref.socket_id,
+                                                **kwargs)
                 listen_address: Address = listen_socket.local
                 if (self.is_sock_tcpip_address(listen_address) and
                         listen_address.tcpip_address.port == port):
                     return server
         return None
 
-    def list_channels(self) -> Iterator[Channel]:
+    def list_channels(self, **kwargs) -> Iterator[Channel]:
         """
         Iterate over all pages of all root channels.
 
@@ -125,12 +128,13 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
             start += 1
             response = self.call_unary_with_deadline(
                 rpc='GetTopChannels',
-                req=_GetTopChannelsRequest(start_channel_id=start))
+                req=_GetTopChannelsRequest(start_channel_id=start),
+                **kwargs)
             for channel in response.channel:
                 start = max(start, channel.ref.channel_id)
                 yield channel
 
-    def list_servers(self) -> Iterator[Server]:
+    def list_servers(self, **kwargs) -> Iterator[Server]:
         """Iterate over all pages of all servers that exist in the process."""
         start: int = -1
         response: Optional[_GetServersResponse] = None
@@ -139,12 +143,14 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
             # value by adding 1 to the highest seen result ID.
             start += 1
             response = self.call_unary_with_deadline(
-                rpc='GetServers', req=_GetServersRequest(start_server_id=start))
+                rpc='GetServers',
+                req=_GetServersRequest(start_server_id=start),
+                **kwargs)
             for server in response.server:
                 start = max(start, server.ref.server_id)
                 yield server
 
-    def list_server_sockets(self, server: Server) -> Iterator[Socket]:
+    def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
         """List all server sockets that exist in server process.
 
         Iterating over the results will resolve additional pages automatically.
@@ -158,39 +164,44 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
             response = self.call_unary_with_deadline(
                 rpc='GetServerSockets',
                 req=_GetServerSocketsRequest(server_id=server.ref.server_id,
-                                             start_socket_id=start))
+                                             start_socket_id=start),
+                **kwargs)
             socket_ref: SocketRef
             for socket_ref in response.socket_ref:
                 start = max(start, socket_ref.socket_id)
                 # Yield actual socket
-                yield self.get_socket(socket_ref.socket_id)
+                yield self.get_socket(socket_ref.socket_id, **kwargs)
 
-    def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]:
+    def list_channel_sockets(self, channel: Channel,
+                             **kwargs) -> Iterator[Socket]:
         """List all sockets of all subchannels of a given channel."""
-        for subchannel in self.list_channel_subchannels(channel):
-            yield from self.list_subchannels_sockets(subchannel)
+        for subchannel in self.list_channel_subchannels(channel, **kwargs):
+            yield from self.list_subchannels_sockets(subchannel, **kwargs)
 
-    def list_channel_subchannels(self,
-                                 channel: Channel) -> Iterator[Subchannel]:
+    def list_channel_subchannels(self, channel: Channel,
+                                 **kwargs) -> Iterator[Subchannel]:
         """List all subchannels of a given channel."""
         for subchannel_ref in channel.subchannel_ref:
-            yield self.get_subchannel(subchannel_ref.subchannel_id)
+            yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)
 
-    def list_subchannels_sockets(self,
-                                 subchannel: Subchannel) -> Iterator[Socket]:
+    def list_subchannels_sockets(self, subchannel: Subchannel,
+                                 **kwargs) -> Iterator[Socket]:
         """List all sockets of a given subchannel."""
         for socket_ref in subchannel.socket_ref:
-            yield self.get_socket(socket_ref.socket_id)
+            yield self.get_socket(socket_ref.socket_id, **kwargs)
 
-    def get_subchannel(self, subchannel_id) -> Subchannel:
+    def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel:
         """Return a single Subchannel, otherwise raises RpcError."""
         response: _GetSubchannelResponse = self.call_unary_with_deadline(
             rpc='GetSubchannel',
-            req=_GetSubchannelRequest(subchannel_id=subchannel_id))
+            req=_GetSubchannelRequest(subchannel_id=subchannel_id),
+            **kwargs)
         return response.subchannel
 
-    def get_socket(self, socket_id) -> Socket:
+    def get_socket(self, socket_id, **kwargs) -> Socket:
         """Return a single Socket, otherwise raises RpcError."""
         response: _GetSocketResponse = self.call_unary_with_deadline(
-            rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id))
+            rpc='GetSocket',
+            req=_GetSocketRequest(socket_id=socket_id),
+            **kwargs)
         return response.socket