Imported Upstream version 1.28.1
[platform/upstream/grpc.git] / src / python / grpcio / grpc / _auth.py
1 # Copyright 2016 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """GRPCAuthMetadataPlugins for standard authentication."""
15
16 import inspect
17 from concurrent import futures
18
19 import grpc
20
21
22 def _sign_request(callback, token, error):
23     metadata = (('authorization', 'Bearer {}'.format(token)),)
24     callback(metadata, error)
25
26
def _create_get_token_callback(callback):
    """Build a ``Future`` done-callback that forwards a token to ``callback``.

    The returned function reads ``result().access_token`` from the completed
    future and passes it to ``_sign_request``. If reading the result or its
    ``access_token`` attribute raises anything, the exception is forwarded as
    the error with a ``None`` token instead.
    """

    def on_token_ready(token_future):
        try:
            token = token_future.result().access_token
        except Exception as failure:  # pylint: disable=broad-except
            _sign_request(callback, None, failure)
            return
        _sign_request(callback, token, None)

    return on_token_ready
38
39
class GoogleCallCredentials(grpc.AuthMetadataPlugin):
    """Metadata wrapper for GoogleCredentials from the oauth2client library.

    Tokens are fetched on a private single-worker thread pool so that
    ``__call__`` itself returns immediately; the token (or the error raised
    while fetching it) is delivered through the gRPC-supplied callback once
    the fetch completes.
    """

    def __init__(self, credentials):
        """Wrap ``credentials``.

        Args:
          credentials: Object exposing ``get_access_token()``; its return value
            must have an ``access_token`` attribute.
        """
        self._credentials = credentials
        self._pool = futures.ThreadPoolExecutor(max_workers=1)

        # Hack to determine if these are JWT creds and we need to pass
        # additional_claims when getting a token
        self._is_jwt = self._accepts_additional_claims(
            credentials.get_access_token)

    @staticmethod
    def _accepts_additional_claims(get_access_token):
        """Return True if ``get_access_token`` declares ``additional_claims``."""
        signature = getattr(inspect, 'signature', None)
        if signature is None:
            # Python 2 has no inspect.signature; getargspec is the only option.
            return 'additional_claims' in inspect.getargspec(  # pylint: disable=deprecated-method
                get_access_token).args
        # inspect.getargspec was removed in Python 3.11 and, before that,
        # raised ValueError for callables with annotations or keyword-only
        # parameters. inspect.signature handles both cases.
        return 'additional_claims' in signature(get_access_token).parameters

    def __call__(self, context, callback):
        """Start an asynchronous token fetch and sign the call when it ends.

        Args:
          context: Plugin context; ``context.service_url`` is sent as the
            ``aud`` claim for JWT credentials.
          callback: Invoked as ``callback(metadata, error)`` from the pool
            thread once the token fetch finishes.
        """
        # MetadataPlugins cannot block (see grpc.beta.interfaces.py)
        if self._is_jwt:
            future = self._pool.submit(
                self._credentials.get_access_token,
                additional_claims={'aud': context.service_url})
        else:
            future = self._pool.submit(self._credentials.get_access_token)
        future.add_done_callback(_create_get_token_callback(callback))

    def __del__(self):
        # Guard against __init__ having failed before the pool was assigned;
        # otherwise finalization would raise AttributeError.
        pool = getattr(self, '_pool', None)
        if pool is not None:
            pool.shutdown(wait=False)
64
65
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
    """Metadata wrapper for raw access token credentials.

    The token given at construction time is reused, unchanged, for every call.
    """

    def __init__(self, access_token):
        """Store ``access_token`` for signing subsequent calls."""
        self._access_token = access_token

    def __call__(self, context, callback):
        """Sign the call synchronously with the stored token.

        ``context`` is accepted for interface compatibility but not consulted.
        """
        token = self._access_token
        _sign_request(callback, token, None)