Imported Upstream version 1.28.1
[platform/upstream/grpc.git] / src / python / grpcio / grpc / _plugin_wrapping.py
1 # Copyright 2015 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import collections
16 import logging
17 import threading
18
19 import grpc
20 from grpc import _common
21 from grpc._cython import cygrpc
22
# Module logger; used to report exceptions raised by user metadata plugins.
_LOGGER = logging.getLogger(name=__name__)
24
25
class _AuthMetadataContext(
        collections.namedtuple('AuthMetadataContext',
                               'service_url method_name'),
        grpc.AuthMetadataContext):
    """Immutable ``grpc.AuthMetadataContext`` handed to metadata plugins.

    A named tuple with exactly two fields, ``service_url`` and
    ``method_name``, that also registers as a ``grpc.AuthMetadataContext``
    through multiple inheritance.
    """
32
33
class _CallbackState(object):
    """Bookkeeping shared by one plugin invocation and its callback wrapper.

    ``called`` and ``exception`` are plain attributes; code that reads or
    writes them is expected to hold ``lock`` while doing so.
    """

    def __init__(self):
        # Nothing has happened yet: the callback has not fired and no
        # exception has been recorded.
        self.called, self.exception = False, None
        # Serializes access to ``called`` and ``exception``.
        self.lock = threading.Lock()
40
41
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Single-use ``grpc.AuthMetadataPluginCallback`` passed to a plugin.

    Forwards the plugin's outcome to an underlying three-argument callback,
    ``callback(metadata, status_code, error_details)``, at most once.
    """

    def __init__(self, state, callback):
        # ``state`` is a _CallbackState-like object: it has ``lock``,
        # ``called`` and ``exception`` attributes.
        self._state = state
        self._callback = callback

    def __call__(self, metadata, error):
        """Report the plugin's result.

        Args:
          metadata: Metadata to forward when ``error`` is None.
          error: None on success. Otherwise its ``str()`` is encoded and
            reported with ``StatusCode.internal``, and ``metadata`` is dropped.

        Raises:
          RuntimeError: If the shared state already records an exception, or
            if this callback has already been invoked.
        """
        state = self._state
        with state.lock:
            if state.exception is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        state.exception))
            if state.called:
                raise RuntimeError(
                    'AuthMetadataPluginCallback invoked more than once!')
            state.called = True
        # The underlying callback is invoked outside the lock.
        if error is not None:
            self._callback(None, cygrpc.StatusCode.internal,
                           _common.encode(str(error)))
            return
        self._callback(metadata, cygrpc.StatusCode.ok, None)
65
66
class _Plugin(object):
    """Adapts a user metadata plugin to the callable shape cygrpc invokes.

    Instances are called as ``(service_url, method_name, callback)``, where
    ``callback`` takes ``(metadata, status_code, error_details)``.
    """

    def __init__(self, metadata_plugin):
        self._metadata_plugin = metadata_plugin

    def __call__(self, service_url, method_name, callback):
        """Run the wrapped plugin for one call.

        Decodes ``service_url`` and ``method_name`` with ``_common.decode``,
        builds an ``_AuthMetadataContext``, and passes it to the plugin
        together with a single-use callback wrapper. If the plugin raises,
        the exception is logged and recorded in the shared state. Unless the
        wrapper already fired, ``callback`` then receives
        ``StatusCode.internal`` with the encoded exception text.
        """
        plugin = self._metadata_plugin
        state = _CallbackState()
        context = _AuthMetadataContext(_common.decode(service_url),
                                       _common.decode(method_name))
        try:
            plugin(context, _AuthMetadataPluginCallback(state, callback))
        except Exception as exception:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!', plugin)
            # Record the failure and check, under the same lock, whether the
            # plugin already delivered a result, so the outcome is reported
            # exactly once.
            with state.lock:
                state.exception = exception
                already_reported = state.called
            if not already_reported:
                callback(None, cygrpc.StatusCode.internal,
                         _common.encode(str(exception)))
89
90
def metadata_plugin_call_credentials(metadata_plugin, name):
    """Wrap a metadata plugin as ``grpc.CallCredentials``.

    Args:
      metadata_plugin: Callable invoked as ``(context, callback)`` for each
        call.
      name: Name for the credentials. When None, the plugin's ``__name__`` is
        used, or its class name if it has no ``__name__`` attribute.

    Returns:
      A ``grpc.CallCredentials`` backed by
      ``cygrpc.MetadataPluginCallCredentials``.
    """
    if name is not None:
        effective_name = name
    else:
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            # For example, callable instances that define __call__.
            effective_name = metadata_plugin.__class__.__name__
    wrapped_plugin = _Plugin(metadata_plugin)
    encoded_name = _common.encode(effective_name)
    return grpc.CallCredentials(
        cygrpc.MetadataPluginCallCredentials(wrapped_plugin, encoded_name))