Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / _xds_credentials_test.py
1 # Copyright 2021 The gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Tests xDS server and channel credentials."""
15
16 import unittest
17
18 import logging
19 from concurrent import futures
20 import contextlib
21
22 import grpc
23 import grpc.experimental
24 from tests.unit import test_common
25 from tests.unit import resources
26
27
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic RPC handler that echoes every request back to the caller."""

    def service(self, handler_call_details):
        """Serve every method with the same echo behavior.

        Args:
          handler_call_details: Details of the incoming call. They are not
            inspected, so any method name is accepted.

        Returns:
          A unary-unary method handler whose behavior returns the request
          object unchanged.
        """

        def _echo(request, unused_context):
            return request

        return grpc.unary_unary_rpc_method_handler(_echo)
33
34
@contextlib.contextmanager
def xds_channel_server_without_xds(server_fallback_creds):
    """Run an echo server secured with xDS server credentials but no xDS.

    The server is created without xDS enabled and no xDS server is available
    to the tests. The xDS credentials are therefore expected to fall back to
    ``server_fallback_creds``.

    Args:
      server_fallback_creds: Server credentials passed as the fallback to
        ``grpc.xds_server_credentials``.

    Yields:
      The ``"localhost:<port>"`` address the server is listening on.
    """
    server = grpc.server(futures.ThreadPoolExecutor())
    server.add_generic_rpc_handlers((_GenericHandler(),))
    # Only the caller-supplied fallback is wrapped. The previous version also
    # built an unused SSL credentials object here (dead code).
    server_creds = grpc.xds_server_credentials(server_fallback_creds)
    # Bind to port 0 and use the port number actually returned by the bind.
    port = server.add_secure_port("localhost:0", server_creds)
    server.start()
    try:
        yield "localhost:{}".format(port)
    finally:
        # grace=None: stop immediately without waiting for in-flight RPCs.
        server.stop(None)
48
49
class XdsCredentialsTest(unittest.TestCase):
    """Tests for xDS server and channel credentials without an xDS server."""

    def _assert_echo(self, channel):
        """Send a fixed payload over ``channel`` and expect it echoed back."""
        payload = b"abc"
        echo = channel.unary_unary("/test/method")
        self.assertEqual(echo(payload, wait_for_ready=True), payload)

    def test_xds_creds_fallback_ssl(self):
        # No xDS server exists, so both sides should use their fallback
        # credentials: SSL in this case.
        key_and_chain = (resources.private_key(), resources.certificate_chain())
        server_fallback_creds = grpc.ssl_server_credentials((key_and_chain,))
        with xds_channel_server_without_xds(
                server_fallback_creds) as server_address:
            channel_fallback_creds = grpc.ssl_channel_credentials(
                root_certificates=resources.test_root_certificates(),
                private_key=resources.private_key(),
                certificate_chain=resources.certificate_chain())
            channel_creds = grpc.xds_channel_credentials(channel_fallback_creds)
            # Override the target name used for the TLS check. Presumably the
            # test certificate covers foo.test.google.fr -- TODO confirm.
            target_override = ("grpc.ssl_target_name_override",
                               "foo.test.google.fr")
            with grpc.secure_channel(server_address,
                                     channel_creds,
                                     options=(target_override,)) as channel:
                self._assert_echo(channel)

    def test_xds_creds_fallback_insecure(self):
        # No xDS server exists, so both sides should use their fallback
        # credentials: insecure credentials in this case.
        server_fallback_creds = grpc.insecure_server_credentials()
        with xds_channel_server_without_xds(
                server_fallback_creds) as server_address:
            insecure_creds = grpc.experimental.insecure_channel_credentials()
            channel_creds = grpc.xds_channel_credentials(insecure_creds)
            with grpc.secure_channel(server_address, channel_creds) as channel:
                self._assert_echo(channel)

    def test_start_xds_server(self):
        # Smoke test: an xDS-enabled server can be configured, started and
        # stopped without raising. The interop tests are meant to provide
        # more comprehensive coverage.
        server = grpc.server(futures.ThreadPoolExecutor(), xds=True)
        server.add_generic_rpc_handlers((_GenericHandler(),))
        server_creds = grpc.xds_server_credentials(
            grpc.insecure_server_credentials())
        server.add_secure_port("localhost:0", server_creds)
        server.start()
        server.stop(None)
99
100
if __name__ == "__main__":
    # Install a default stderr logging handler so warnings logged during the
    # run are shown, then run the test cases defined in this module.
    logging.basicConfig()
    unittest.main()