Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio / grpc / _server.py
1 # Copyright 2016 gRPC authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Service-side implementation of gRPC Python."""
15
16 import collections
17 import enum
18 import logging
19 import threading
20 import time
21
22 from concurrent import futures
23 import six
24
25 import grpc
26 from grpc import _common
27 from grpc import _compression
28 from grpc import _interceptor
29 from grpc._cython import cygrpc
30
_LOGGER = logging.getLogger(__name__)

# Completion-queue tags for server-level (not per-RPC) events.
_SHUTDOWN_TAG = 'shutdown'
_REQUEST_CALL_TAG = 'request_call'

# Per-RPC tokens. A token is added to _RPCState.due when the matching batch
# is started and removed by _possibly_finish_call when that batch completes.
_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
_RECEIVE_MESSAGE_TOKEN = 'receive_message'
_SEND_MESSAGE_TOKEN = 'send_message'
# Used when initial metadata is sent in the same batch as a message.
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    'send_initial_metadata * send_message')
_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
# Used when initial metadata is sent in the same batch as the status.
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    'send_initial_metadata * send_status_from_server')

# Values of _RPCState.client: still open, finished sending (half-closed or
# call closed), or cancelled.
_OPEN = 'open'
_CLOSED = 'closed'
_CANCELLED = 'cancelled'

# "No flags" value passed to cygrpc operations.
_EMPTY_FLAGS = 0

# NOTE(review): the next two are not referenced in this part of the file;
# by name, a polling period in seconds for detecting server deallocation and
# a stand-in for "no timeout" — confirm at their use sites.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9
54
55
def _serialized_request(request_event):
    """Returns the serialized payload of the event's first batch operation.

    A None result means no message was received (receive_message treats it
    as the client having finished sending).
    """
    first_operation = request_event.batch_operations[0]
    return first_operation.message()
58
59
def _application_code(code):
    """Maps an application grpc.StatusCode to its cygrpc status code.

    Codes with no mapping become cygrpc.StatusCode.unknown.
    """
    mapped = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if mapped is None:
        return cygrpc.StatusCode.unknown
    return mapped
63
64
def _completion_code(state):
    """Status code for a normally-completing RPC.

    Uses the code the application set on ``state``, or OK if none was set.
    """
    if state.code is not None:
        return _application_code(state.code)
    return cygrpc.StatusCode.ok
70
71
def _abortion_code(state, code):
    """Status code for an aborted RPC.

    An application-set code on ``state`` takes precedence over the
    server-chosen ``code``.
    """
    application_code = state.code
    if application_code is not None:
        return _application_code(application_code)
    return code
77
78
def _details(state):
    """Returns the status details recorded on ``state``, or b'' if unset."""
    details = state.details
    return details if details is not None else b''
81
82
class _HandlerCallDetails(
        collections.namedtuple('_HandlerCallDetails',
                               ['method', 'invocation_metadata']),
        grpc.HandlerCallDetails):
    """Immutable (method, invocation_metadata) pair given to generic handlers."""
89
90
class _RPCState(object):
    """Mutable bookkeeping for one server-side RPC.

    Shared between the handler thread and completion-queue callbacks; code
    in this module generally touches it while holding ``condition``.
    """

    def __init__(self):
        self.condition = threading.Condition()
        # Tokens of batches that were started and have not completed yet.
        self.due = set()

        # Client side of the call and the latest request not yet consumed.
        self.client = _OPEN
        self.request = None

        # Outgoing metadata and compression controls.
        self.initial_metadata_allowed = True
        self.compression_algorithm = None
        self.disable_next_compression = False
        self.trailing_metadata = None

        # Final status chosen by the application / whether it was sent.
        self.code = None
        self.details = None
        self.aborted = False
        self.statused = False

        # RpcErrors this module raised into application code, and callbacks
        # to hand off once the call is finished (None after hand-off).
        self.rpc_errors = []
        self.callbacks = []
108
109
def _raise_rpc_error(state):
    """Raises a new grpc.RpcError after recording it on ``state``.

    Recording it lets handler-exception processing recognize errors that
    this module raised itself.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
114
115
def _possibly_finish_call(state, token):
    """Marks the batch identified by ``token`` as completed.

    Once the RPC is inactive and no other batch is outstanding, returns
    ``(state, callbacks)`` and clears ``state.callbacks`` to None so no more
    callbacks can be registered. Otherwise returns ``(None, ())``.
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        return None, ()
    pending_callbacks, state.callbacks = state.callbacks, None
    return state, pending_callbacks
124
125
def _send_status_from_server(state, token):
    """Returns the completion callback for a send-status batch."""

    def on_status_sent(unused_send_status_from_server_event):
        with state.condition:
            return _possibly_finish_call(state, token)

    return on_status_sent
133
134
def _get_initial_metadata(state, metadata):
    """Returns ``metadata``, prefixed with the compression entry when a
    compression algorithm has been set on ``state``."""
    with state.condition:
        algorithm = state.compression_algorithm
        if not algorithm:
            return metadata
        compression_entry = (
            _compression.compression_algorithm_to_metadata(algorithm),)
        if metadata is None:
            return compression_entry
        return compression_entry + tuple(metadata)
147
148
def _get_initial_metadata_operation(state, metadata):
    """Wraps the (possibly compression-augmented) metadata in a
    SendInitialMetadataOperation with no flags."""
    return cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS)
153
154
def _abort(state, call, code, details):
    """Sends a final status for the RPC unless the client has cancelled.

    A code or details value already set by the application on ``state``
    overrides ``code`` or ``details``. If initial metadata has not been sent
    yet, it goes out in the same batch. Callers in this module invoke this
    while holding ``state.condition``.
    """
    if state.client is _CANCELLED:
        return
    status_code = _abortion_code(state, code)
    status_details = details if state.details is None else state.details
    if state.initial_metadata_allowed:
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
        leading_operations = (_get_initial_metadata_operation(state, None),)
    else:
        token = _SEND_STATUS_FROM_SERVER_TOKEN
        leading_operations = ()
    status_operation = cygrpc.SendStatusFromServerOperation(
        state.trailing_metadata, status_code, status_details, _EMPTY_FLAGS)
    call.start_server_batch(leading_operations + (status_operation,),
                            _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)
177
178
def _receive_close_on_server(state):
    """Returns the completion callback for the receive-close-on-server op.

    The callback records whether the client cancelled or simply closed,
    wakes every waiter on ``state.condition``, and reports whether the call
    is now finished.
    """

    def on_close_received(receive_close_on_server_event):
        with state.condition:
            close_operation = receive_close_on_server_event.batch_operations[0]
            if close_operation.cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return on_close_received
191
192
def _receive_message(state, call, request_deserializer):
    """Returns the completion callback for a receive-message batch.

    The callback handles three outcomes, always waking every waiter on
    ``state.condition`` afterwards:

    * No message: the client side is marked closed.
    * Message deserialized (outside the lock): it is stored in
      ``state.request``.
    * Deserialization failed: the RPC is aborted with INTERNAL.
    """

    def on_message_received(receive_message_event):
        serialized_request = _serialized_request(receive_message_event)
        if serialized_request is not None:
            request = _common.deserialize(serialized_request,
                                          request_deserializer)
            with state.condition:
                if request is not None:
                    state.request = request
                else:
                    _abort(state, call, cygrpc.StatusCode.internal,
                           b'Exception deserializing request!')
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

        with state.condition:
            if state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return on_message_received
216
217
def _send_initial_metadata(state):
    """Returns the completion callback for a send-initial-metadata batch."""

    def on_initial_metadata_sent(unused_send_initial_metadata_event):
        with state.condition:
            return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)

    return on_initial_metadata_sent
225
226
def _send_message(state, token):
    """Returns the completion callback for the send-message batch ``token``.

    The callback wakes waiters on ``state.condition`` (such as the loop in
    _send_response) before retiring the token.
    """

    def on_message_sent(unused_send_message_event):
        with state.condition:
            state.condition.notify_all()
            return _possibly_finish_call(state, token)

    return on_message_sent
235
236
class _Context(grpc.ServicerContext):
    """grpc.ServicerContext implementation for a single server-side RPC.

    Call-level information (metadata, peer, deadline, auth context) comes
    from the request-call event. Application choices (status code, details,
    trailing metadata, compression) are recorded on the shared _RPCState
    while holding ``state.condition``.
    """

    def __init__(self, rpc_event, state, request_deserializer):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self):
        """Returns True until the client cancels or a status has been sent."""
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self):
        """Returns the seconds left before the call deadline, floored at 0."""
        # NOTE(review): assumes call_details.deadline is on the time.time()
        # clock (absolute epoch seconds) — confirm in cygrpc.
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self):
        """Cancels the underlying call."""
        self._rpc_event.call.cancel()

    def add_callback(self, callback):
        """Registers ``callback`` for hand-off when the RPC finishes.

        Returns False without registering once the callback list has already
        been handed off (set to None by _possibly_finish_call).
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            self._state.callbacks.append(callback)
            return True

    def disable_next_message_compression(self):
        """Sends the next response message without compression."""
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self):
        """Returns the metadata the client sent with the call."""
        return self._rpc_event.invocation_metadata

    def peer(self):
        """Returns the peer of the call, decoded to text."""
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self):
        """Returns the peer identities reported by cygrpc for the call."""
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self):
        """Returns the peer identity key decoded to text, or None."""
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return None if id_key is None else _common.decode(id_key)

    def auth_context(self):
        """Returns the call's auth context with its keys decoded to text."""
        return {
            _common.decode(key): value for key, value in six.iteritems(
                cygrpc.auth_context(self._rpc_event.call))
        }

    def set_compression(self, compression):
        """Sets the compression algorithm advertised in initial metadata."""
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata):
        """Sends ``initial_metadata`` right away.

        Raises:
          grpc.RpcError: if the client has cancelled the RPC.
          ValueError: if initial metadata was already sent (directly or
            together with a response message).
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            elif self._state.initial_metadata_allowed:
                operation = _get_initial_metadata_operation(
                    self._state, initial_metadata)
                self._rpc_event.call.start_server_batch(
                    (operation,), _send_initial_metadata(self._state))
                self._state.initial_metadata_allowed = False
                self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
            else:
                raise ValueError('Initial metadata no longer allowed!')

    def set_trailing_metadata(self, trailing_metadata):
        """Sets the trailing metadata sent with the final status."""
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def abort(self, code, details):
        """Terminates the RPC with ``code`` and ``details``.

        The status is recorded on the RPC state, ``state.aborted`` is set,
        and a bare Exception is raised to unwind the handler. Code that
        catches it sees ``state.aborted`` and sends the recorded status.
        StatusCode.OK is not a valid abort code and is replaced with UNKNOWN
        and empty details.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                'abort() called with StatusCode.OK; returning UNKNOWN')
            code = grpc.StatusCode.UNKNOWN
            details = ''
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status):
        """Aborts the RPC using ``status``'s code, details and trailing metadata."""
        # Take the condition like set_trailing_metadata does, so the trailing
        # metadata and the code/details written by abort() are published
        # together. threading.Condition() defaults to a reentrant lock, so
        # abort() can re-acquire it.
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
            self.abort(status.code, status.details)

    def set_code(self, code):
        """Sets the status code reported when the RPC completes."""
        with self._state.condition:
            self._state.code = code

    def set_details(self, details):
        """Sets the status details (encoded to bytes) reported at completion."""
        with self._state.condition:
            self._state.details = _common.encode(details)

    def _finalize_state(self):
        # No-op hook. NOTE(review): presumably invoked when the servicer
        # context created by _create_servicer_context exits — confirm.
        pass
335
336
class _RequestIterator(object):
    """Blocking iterator over the request messages of a client-streaming RPC.

    Each step starts a receive-message batch and then waits on
    ``state.condition`` until receive_message stores a request or the
    stream ends.
    """

    def __init__(self, state, call, request_deserializer):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self):
        # Called from _next with self._state.condition held.
        # Cancelled client -> grpc.RpcError; RPC otherwise inactive (a status
        # was already sent) -> end of iteration; else start receiving the
        # next message and mark that batch as outstanding.
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif not _is_rpc_state_active(self._state):
            raise StopIteration()
        else:
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(self._state, self._call,
                                 self._request_deserializer))
            self._state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self):
        # Called with self._state.condition held after a wakeup.
        # Returns the received request (consuming it from the state), or None
        # while the receive batch is still outstanding. Ends iteration once
        # the batch completed without delivering a request.
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif (self._state.request is None and
              _RECEIVE_MESSAGE_TOKEN not in self._state.due):
            raise StopIteration()
        else:
            request = self._state.request
            self._state.request = None
            return request

        raise AssertionError()  # should never run

    def _next(self):
        # Start one receive, then wait for notifications until a request is
        # available; _look_for_request raises to end iteration or on cancel.
        with self._state.condition:
            self._raise_or_start_receive_message()
            while True:
                self._state.condition.wait()
                request = self._look_for_request()
                if request is not None:
                    return request

    def __iter__(self):
        return self

    def __next__(self):
        return self._next()

    # Python 2 spelling of the iterator protocol method.
    def next(self):
        return self._next()
386
387
def _unary_request(rpc_event, state, request_deserializer):
    """Returns a thunk that receives the single request of a unary-request RPC.

    The thunk starts a receive-message batch and waits on
    ``state.condition`` until one of these happens:

    * a request arrives: it is returned;
    * the RPC becomes inactive (client cancelled, or a status was already
      sent, e.g. by receive_message after a deserialization failure):
      None is returned;
    * the client closes without sending a message: the RPC is aborted with
      UNIMPLEMENTED and None is returned.

    A None result means there is no request to hand to the application.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            rpc_event.call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(state, rpc_event.call, request_deserializer))
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
            while True:
                state.condition.wait()
                if state.request is not None:
                    request = state.request
                    state.request = None
                    return request
                if not _is_rpc_state_active(state):
                    # Cancelled, or already statused (receive_message aborts
                    # on a deserialization failure). Aborting again once the
                    # client side closes would start a second send-status
                    # batch on a finished call.
                    return None
                if state.client is _CLOSED:
                    # call_details.method is passed through _common.decode
                    # when looking up handlers, so it may be bytes; decode it
                    # so the details don't render as "b'/pkg.Svc/Method'".
                    details = '"{}" requires exactly one request message.'.format(
                        _common.decode(rpc_event.call_details.method))
                    _abort(state, rpc_event.call,
                           cygrpc.StatusCode.unimplemented,
                           _common.encode(details))
                    return None

    return unary_request
418
419
def _call_behavior(rpc_event,
                   state,
                   behavior,
                   argument,
                   request_deserializer,
                   send_response_callback=None):
    """Runs the application handler ``behavior`` inside a servicer context.

    Returns:
      ``(result, True)`` when the handler returns normally, where ``result``
      is its response or response iterator.

      ``(None, False)`` when the handler raises:

      * if it called ``context.abort``, the recorded status is sent;
      * if the exception is a grpc.RpcError this module raised into the
        handler, nothing further is sent;
      * otherwise the exception is logged and the RPC fails with UNKNOWN.
    """
    from grpc import _create_servicer_context
    if send_response_callback is None:
        extra_arguments = ()
    else:
        extra_arguments = (send_response_callback,)
    with _create_servicer_context(rpc_event, state,
                                  request_deserializer) as context:
        try:
            return behavior(argument, context, *extra_arguments), True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           b'RPC Aborted')
                elif exception not in state.rpc_errors:
                    details = 'Exception calling application: {}'.format(
                        exception)
                    _LOGGER.exception(details)
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           _common.encode(details))
            return None, False
449
450
def _take_response_from_response_iterator(rpc_event, state, response_iterator):
    """Pulls the next response from the application's response iterator.

    Returns:
      ``(response, True)`` for the next response.
      ``(None, True)`` when the iterator is exhausted.
      ``(None, False)`` when it raised; the RPC is then aborted unless the
      error is one this module raised into the handler.
    """
    try:
        response = next(response_iterator)
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if state.aborted:
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       b'RPC Aborted')
            elif exception not in state.rpc_errors:
                details = 'Exception iterating responses: {}'.format(exception)
                _LOGGER.exception(details)
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       _common.encode(details))
        return None, False
    return response, True
467
468
def _serialize_response(rpc_event, state, response, response_serializer):
    """Serializes ``response``.

    If serialization fails (None result), aborts the RPC with INTERNAL and
    returns None.
    """
    serialized_response = _common.serialize(response, response_serializer)
    if serialized_response is not None:
        return serialized_response
    with state.condition:
        _abort(state, rpc_event.call, cygrpc.StatusCode.internal,
               b'Failed to serialize response!')
    return None
478
479
def _get_send_message_op_flags_from_state(state):
    """Write flags for the next outgoing message: no-compress if requested."""
    return (cygrpc.WriteFlag.no_compress
            if state.disable_next_compression else _EMPTY_FLAGS)
485
486
def _reset_per_message_state(state):
    """Clears per-message options (the one-shot no-compress request)."""
    condition = state.condition
    with condition:
        state.disable_next_compression = False
490
491
def _send_response(rpc_event, state, serialized_response):
    """Sends one serialized response message and waits until it is sent.

    If initial metadata has not been sent yet, it goes out in the same
    batch. Returns False without sending if the RPC is already inactive;
    otherwise returns whether the RPC is still active once the send batch
    has completed.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        else:
            if state.initial_metadata_allowed:
                operations = (
                    _get_initial_metadata_operation(state, None),
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state)),
                )
                state.initial_metadata_allowed = False
                token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
            else:
                operations = (cygrpc.SendMessageOperation(
                    serialized_response,
                    _get_send_message_op_flags_from_state(state)),)
                token = _SEND_MESSAGE_TOKEN
            rpc_event.call.start_server_batch(operations,
                                              _send_message(state, token))
            state.due.add(token)
            # Re-acquires state.condition; this is fine because
            # threading.Condition() defaults to a reentrant lock.
            _reset_per_message_state(state)
            # The send_message callback notifies and then retires the token;
            # loop until that has happened.
            while True:
                state.condition.wait()
                if token not in state.due:
                    return _is_rpc_state_active(state)
519
520
def _status(rpc_event, state, serialized_response):
    """Sends the final status in one batch, unless the client has cancelled.

    The batch also carries any initial metadata not yet sent and, when
    ``serialized_response`` is given, that final message.
    """
    with state.condition:
        if state.client is _CANCELLED:
            return
        code = _completion_code(state)
        details = _details(state)
        operations = [
            cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
                                                 code, details, _EMPTY_FLAGS),
        ]
        if state.initial_metadata_allowed:
            operations.append(_get_initial_metadata_operation(state, None))
        if serialized_response is not None:
            message_flags = _get_send_message_op_flags_from_state(state)
            operations.append(
                cygrpc.SendMessageOperation(serialized_response,
                                            message_flags))
        on_sent = _send_status_from_server(state,
                                           _SEND_STATUS_FROM_SERVER_TOKEN)
        rpc_event.call.start_server_batch(operations, on_sent)
        state.statused = True
        _reset_per_message_state(state)
        state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
544
545
def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
                            request_deserializer, response_serializer):
    """Thread-pool entry point for RPCs that return a single response.

    Obtains the request argument, runs the handler, then sends its
    serialized response together with the final status. Stops early
    whenever a step reports that the RPC should not proceed.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)
    try:
        argument = argument_thunk()
        if argument is None:
            return
        response, proceed = _call_behavior(rpc_event, state, behavior,
                                           argument, request_deserializer)
        if not proceed:
            return
        serialized_response = _serialize_response(rpc_event, state, response,
                                                  response_serializer)
        if serialized_response is not None:
            _status(rpc_event, state, serialized_response)
    finally:
        cygrpc.uninstall_context()
561
562
def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
                             request_deserializer, response_serializer):
    """Thread-pool entry point for RPCs with a response stream.

    Handlers flagged ``experimental_non_blocking`` receive a
    ``send_response`` callback. Other handlers return an iterator, which is
    drained into that same callback. Passing None to the callback sends the
    final status.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response):
        if response is None:
            _status(rpc_event, state, None)
            return
        serialized_response = _serialize_response(rpc_event, state, response,
                                                  response_serializer)
        if serialized_response is not None:
            _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        non_blocking = (hasattr(behavior, 'experimental_non_blocking') and
                        behavior.experimental_non_blocking)
        if non_blocking:
            _call_behavior(rpc_event,
                           state,
                           behavior,
                           argument,
                           request_deserializer,
                           send_response_callback=send_response)
            return
        response_iterator, proceed = _call_behavior(rpc_event, state,
                                                    behavior, argument,
                                                    request_deserializer)
        if proceed:
            _send_message_callback_to_blocking_iterator_adapter(
                rpc_event, state, send_response, response_iterator)
    finally:
        cygrpc.uninstall_context()
596
597
def _is_rpc_state_active(state):
    """True unless the client cancelled or a status has already been sent."""
    return not (state.client is _CANCELLED or state.statused)
600
601
def _send_message_callback_to_blocking_iterator_adapter(rpc_event, state,
                                                        send_response_callback,
                                                        response_iterator):
    """Feeds ``response_iterator`` into ``send_response_callback``.

    Stops when the iterator fails or the RPC stops being active. Exhaustion
    of the iterator is delivered to the callback as None.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator)
        if not proceed:
            return
        send_response_callback(response)
        if not _is_rpc_state_active(state):
            return
614
615
def _select_thread_pool_for_behavior(behavior, default_thread_pool):
    """Returns the behavior's ``experimental_thread_pool`` if it is a
    ThreadPoolExecutor, otherwise ``default_thread_pool``."""
    has_own_pool = (hasattr(behavior, 'experimental_thread_pool') and
                    isinstance(behavior.experimental_thread_pool,
                               futures.ThreadPoolExecutor))
    if not has_own_pool:
        return default_thread_pool
    return behavior.experimental_thread_pool
622
623
def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool):
    """Schedules a unary-request/unary-response RPC; returns the future."""
    deserializer = method_handler.request_deserializer
    request_thunk = _unary_request(rpc_event, state, deserializer)
    behavior = method_handler.unary_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_unary_response_in_pool, rpc_event, state, behavior,
                       request_thunk, deserializer,
                       method_handler.response_serializer)
633
634
def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool):
    """Schedules a unary-request/streaming-response RPC; returns the future."""
    deserializer = method_handler.request_deserializer
    request_thunk = _unary_request(rpc_event, state, deserializer)
    behavior = method_handler.unary_stream
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_stream_response_in_pool, rpc_event, state, behavior,
                       request_thunk, deserializer,
                       method_handler.response_serializer)
644
645
def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool):
    """Schedules a streaming-request/unary-response RPC; returns the future."""
    deserializer = method_handler.request_deserializer
    request_iterator = _RequestIterator(state, rpc_event.call, deserializer)
    behavior = method_handler.stream_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_unary_response_in_pool, rpc_event, state, behavior,
                       lambda: request_iterator, deserializer,
                       method_handler.response_serializer)
656
657
def _handle_stream_stream(rpc_event, state, method_handler,
                          default_thread_pool):
    """Schedules a bidirectional-streaming RPC; returns the future."""
    deserializer = method_handler.request_deserializer
    request_iterator = _RequestIterator(state, rpc_event.call, deserializer)
    behavior = method_handler.stream_stream
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_stream_response_in_pool, rpc_event, state, behavior,
                       lambda: request_iterator, deserializer,
                       method_handler.response_serializer)
669
670
def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
    """Looks up the method handler for the call described by ``rpc_event``.

    Generic handlers are consulted in order and the first non-None result
    wins. When an interceptor pipeline is configured, the lookup is routed
    through it. Returns None if no handler serves the method.
    """
    call_details = _HandlerCallDetails(
        _common.decode(rpc_event.call_details.method),
        rpc_event.invocation_metadata)

    def query_handlers(handler_call_details):
        for candidate in generic_handlers:
            found = candidate.service(handler_call_details)
            if found is not None:
                return found
        return None

    if interceptor_pipeline is None:
        return query_handlers(call_details)
    return interceptor_pipeline.execute(query_handlers, call_details)
689
690
def _reject_rpc(rpc_event, status, details):
    """Fails the call immediately with ``status``/``details``.

    Sends initial metadata and the status and receives the close, all in
    one batch. Returns the fresh _RPCState created for the call.
    """
    rejected_state = _RPCState()
    initial_metadata_op = _get_initial_metadata_operation(rejected_state, None)
    close_op = cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS)
    status_op = cygrpc.SendStatusFromServerOperation(None, status, details,
                                                     _EMPTY_FLAGS)

    def on_batch_complete(ignored_event):
        return rejected_state, ()

    rpc_event.call.start_server_batch(
        (initial_metadata_op, close_op, status_op), on_batch_complete)
    return rejected_state
704
705
def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
    """Starts serving a call with ``method_handler``.

    Registers the receive-close op, then dispatches on the handler's
    request/response streaming flags. Returns ``(state, future)``.
    """
    state = _RPCState()
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state))
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        if method_handler.request_streaming:
            handle = (_handle_stream_stream
                      if method_handler.response_streaming else
                      _handle_stream_unary)
        else:
            handle = (_handle_unary_stream
                      if method_handler.response_streaming else
                      _handle_unary_unary)
        return state, handle(rpc_event, state, method_handler, thread_pool)
727
728
def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool,
                 concurrency_exceeded):
    """Dispatches a newly requested call.

    Returns one of:

    * ``(None, None)`` for an unsuccessful event or one without a method;
    * ``(state, None)`` when the call is rejected (handler lookup raised,
      no handler found, or the concurrency limit was exceeded);
    * otherwise the ``(state, future)`` pair from
      _handle_with_method_handler.
    """
    if not rpc_event.success or rpc_event.call_details.method is None:
        return None, None
    try:
        method_handler = _find_method_handler(rpc_event, generic_handlers,
                                              interceptor_pipeline)
    except Exception as exception:  # pylint: disable=broad-except
        details = 'Exception servicing handler: {}'.format(exception)
        _LOGGER.exception(details)
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown,
                           b'Error in service handler!'), None
    if method_handler is None:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
                           b'Method not found!'), None
    if concurrency_exceeded:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
                           b'Concurrent RPC limit exceeded!'), None
    return _handle_with_method_handler(rpc_event, method_handler, thread_pool)
753
754
@enum.unique
class _ServerStage(enum.Enum):
    """Lifecycle stage of the server, stored in _ServerState.stage."""
    STOPPED = 'stopped'
    STARTED = 'started'
    # NOTE(review): presumably "shutting down within a grace period" — confirm.
    GRACE = 'grace'
760
761
class _ServerState(object):
    """Mutable server-wide state; fields are guarded by ``lock``."""

    # pylint: disable=too-many-arguments
    def __init__(self, completion_queue, server, generic_handlers,
                 interceptor_pipeline, thread_pool, maximum_concurrent_rpcs):
        # Reentrant lock taken by the module-level helpers below.
        self.lock = threading.RLock()
        # Queue on which incoming calls are requested (see _request_call).
        self.completion_queue = completion_queue
        # cygrpc server: ports are added to it, calls requested from it, and
        # it is destroyed by _stop_serving.
        self.server = server
        # Copied so later _add_generic_handlers calls can extend it in place.
        self.generic_handlers = list(generic_handlers)
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool
        self.stage = _ServerStage.STOPPED
        self.termination_event = threading.Event()
        # Every event in this list is set when serving stops.
        self.shutdown_events = [self.termination_event]
        # NOTE(review): presumably compared against active_rpc_count to
        # decide concurrency_exceeded — confirm at the call site.
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
        # Decremented by _on_call_completed.
        self.active_rpc_count = 0

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        # Serving can only stop once both are empty (see _stop_serving);
        # ``due`` holds outstanding server-level tags.
        self.rpc_states = set()
        self.due = set()

        # A "volatile" flag to interrupt the daemon serving thread
        self.server_deallocated = False
785
786
787 def _add_generic_handlers(state, generic_handlers):
788     with state.lock:
789         state.generic_handlers.extend(generic_handlers)
790
791
792 def _add_insecure_port(state, address):
793     with state.lock:
794         return state.server.add_http2_port(address)
795
796
797 def _add_secure_port(state, address, server_credentials):
798     with state.lock:
799         return state.server.add_http2_port(address,
800                                            server_credentials._credentials)
801
802
def _request_call(state):
    """Request the next incoming call from the core server.

    The outcome is reported on ``state.completion_queue`` under
    ``_REQUEST_CALL_TAG``. The tag is recorded in ``state.due`` so that
    ``_stop_serving`` will not stop while it is outstanding. The call sites
    in this module invoke this while holding ``state.lock``.
    """
    queue = state.completion_queue
    state.server.request_call(queue, queue, _REQUEST_CALL_TAG)
    state.due.add(_REQUEST_CALL_TAG)
807
808
809 # TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
810 def _stop_serving(state):
811     if not state.rpc_states and not state.due:
812         state.server.destroy()
813         for shutdown_event in state.shutdown_events:
814             shutdown_event.set()
815         state.stage = _ServerStage.STOPPED
816         return True
817     else:
818         return False
819
820
821 def _on_call_completed(state):
822     with state.lock:
823         state.active_rpc_count -= 1
824
825
def _process_event_and_continue(state, event):
    """Handle one completion-queue event on the serving thread.

    Events are distinguished by their tag:

    * ``_SHUTDOWN_TAG``: the event for the core shutdown requested by
      ``_begin_shutdown_once``.
    * ``_REQUEST_CALL_TAG``: the event for a call requested by
      ``_request_call``. It is dispatched through ``_handle_call``. While the
      server is still STARTED, another call is requested.
    * anything else: the tag is a callable. Calling it with the event returns
      ``(rpc_state, callbacks)`` for an in-flight RPC.

    Args:
      state: The server's ``_ServerState``.
      event: An event returned by ``state.completion_queue.poll``.

    Returns:
      False once ``_stop_serving`` reports that the server has fully stopped,
      telling the serving loop to exit. True otherwise.
    """
    should_continue = True
    # Identity comparison works because the module-level tag objects
    # themselves are what get handed to the core and come back on the event.
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif event.tag is _REQUEST_CALL_TAG:
        with state.lock:
            state.due.remove(_REQUEST_CALL_TAG)
            # A limit of None means no concurrency limit.
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None and
                state.active_rpc_count >= state.maximum_concurrent_rpcs)
            rpc_state, rpc_future = _handle_call(event, state.generic_handlers,
                                                 state.interceptor_pipeline,
                                                 state.thread_pool,
                                                 concurrency_exceeded)
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            if rpc_future is not None:
                # Count the RPC as active until its handler future is done.
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state))
            if state.stage is _ServerStage.STARTED:
                # Keep exactly one call request outstanding while serving.
                _request_call(state)
            # Shutdown has begun, so the request is not renewed. This may
            # have been the last outstanding item.
            elif _stop_serving(state):
                should_continue = False
    else:
        # Per-RPC event: the tag and its callbacks run outside state.lock.
        rpc_state, callbacks = event.tag(event)
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                # A failing callback is logged and must not kill the loop.
                _LOGGER.exception('Exception calling callback!')
        # A non-None rpc_state means that RPC no longer needs tracking.
        if rpc_state is not None:
            with state.lock:
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue
866
867
def _serve(state):
    """Body of the daemon serving thread started by ``_start``.

    Repeatedly polls ``state.completion_queue`` and hands each non-timeout
    event to ``_process_event_and_continue``. Each poll is bounded by
    ``_DEALLOCATED_SERVER_CHECK_PERIOD_S``, so the ``server_deallocated``
    flag keeps being re-checked even when no events arrive. Returns once the
    server reports that it has fully stopped.
    """
    while True:
        # NOTE: despite its name this is an absolute time (now + period), not
        # a duration; presumably cygrpc's poll() treats it as a deadline.
        timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(timeout)
        # _Server.__del__ cannot take the lock, so it only sets this flag; the
        # serving thread begins the shutdown on its behalf.
        if state.server_deallocated:
            _begin_shutdown_once(state)
        # A queue_timeout event carries no work; just loop and re-check.
        if event.completion_type != cygrpc.CompletionType.queue_timeout:
            if not _process_event_and_continue(state, event):
                return
        # We want to force the deletion of the previous event
        # ~before~ we poll again; if the event has a reference
        # to a shutdown Call object, this can induce spinlock.
        event = None
881
882
def _begin_shutdown_once(state):
    """Request core shutdown for a STARTED server; otherwise do nothing.

    On a STARTED server this asks the core to shut down, reporting on
    ``state.completion_queue`` with ``_SHUTDOWN_TAG``. It then moves the
    stage to GRACE and records the tag as outstanding in ``state.due``.
    Calls made in any other stage have no effect.
    """
    with state.lock:
        if state.stage is not _ServerStage.STARTED:
            return
        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
        state.stage = _ServerStage.GRACE
        state.due.add(_SHUTDOWN_TAG)
889
890
def _stop(state, grace):
    """Stop the server, either immediately or after a grace period.

    Args:
      state: The server's ``_ServerState``.
      grace: ``None`` cancels all in-flight calls right away. Otherwise it is
        a number of seconds; once it elapses (or once the server stops, if
        that happens first) all remaining calls are cancelled.

    Returns:
      A ``threading.Event`` that is set once the server has fully stopped.
      If the server is already STOPPED, an already-set event is returned.
      With ``grace=None``, this call blocks until the server stops. With a
      grace period, it returns the (not yet set) event immediately.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            shutdown_event = threading.Event()
            shutdown_event.set()
            return shutdown_event
        else:
            # No-op if shutdown was already requested (stage is GRACE).
            _begin_shutdown_once(state)
            shutdown_event = threading.Event()
            # _stop_serving sets every registered event once stopping is done.
            state.shutdown_events.append(shutdown_event)
            if grace is None:
                state.server.cancel_all_calls()
            else:

                def cancel_all_calls_after_grace():
                    # Returns early if the server stops within the grace
                    # period; cancel_all_calls is still invoked either way.
                    shutdown_event.wait(timeout=grace)
                    with state.lock:
                        state.server.cancel_all_calls()

                # Not a daemon thread: interpreter exit waits for it.
                thread = threading.Thread(target=cancel_all_calls_after_grace)
                thread.start()
                return shutdown_event
    # Only reached when grace is None. Wait outside the lock so the serving
    # thread can acquire it and finish stopping.
    shutdown_event.wait()
    return shutdown_event
915
916
def _start(state):
    """Start the core server and launch the serving thread.

    Raises:
      ValueError: if the server's stage is not STOPPED.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError('Cannot start already-started server!')
        state.server.start()
        state.stage = _ServerStage.STARTED
        # Prime the completion queue with the first call request; the serving
        # thread re-issues one after each request-call event.
        _request_call(state)

        # Daemon thread, so it does not on its own keep the interpreter alive.
        serving_thread = threading.Thread(target=_serve, args=(state,))
        serving_thread.daemon = True
        serving_thread.start()
928
929
930 def _validate_generic_rpc_handlers(generic_rpc_handlers):
931     for generic_rpc_handler in generic_rpc_handlers:
932         service_attribute = getattr(generic_rpc_handler, 'service', None)
933         if service_attribute is None:
934             raise AttributeError(
935                 '"{}" must conform to grpc.GenericRpcHandler type but does '
936                 'not have "service" method!'.format(generic_rpc_handler))
937
938
def _augment_options(base_options, compression):
    """Return ``base_options`` as a tuple with the compression option appended.

    ``compression`` is converted by ``_compression.create_channel_option``.
    The result is concatenated after the caller-supplied options.
    """
    extra = _compression.create_channel_option(compression)
    combined = tuple(base_options)
    combined += extra
    return combined
942
943
class _Server(grpc.Server):
    """``grpc.Server`` implementation on top of a cygrpc server.

    Each instance owns one completion queue and one cygrpc server. All
    mutable state lives in ``self._state``, and the public methods delegate
    to the module-level helpers that operate on it.
    """

    # pylint: disable=too-many-arguments
    def __init__(self, thread_pool, generic_handlers, interceptors, options,
                 maximum_concurrent_rpcs, compression):
        completion_queue = cygrpc.CompletionQueue()
        # Compression is handed to the core as an extra channel option.
        server = cygrpc.Server(_augment_options(options, compression))
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(completion_queue, server, generic_handlers,
                                   _interceptor.service_pipeline(interceptors),
                                   thread_pool, maximum_concurrent_rpcs)

    def add_generic_rpc_handlers(self, generic_rpc_handlers):
        # Materialize once. Validation iterates the handlers before they are
        # registered, so a one-shot iterable (e.g. a generator) would be
        # exhausted by validation and silently register nothing.
        generic_rpc_handlers = tuple(generic_rpc_handlers)
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_insecure_port(self, address):
        return _add_insecure_port(self._state, _common.encode(address))

    def add_secure_port(self, address, server_credentials):
        return _add_secure_port(self._state, _common.encode(address),
                                server_credentials)

    def start(self):
        _start(self._state)

    def wait_for_termination(self, timeout=None):
        # NOTE(https://bugs.python.org/issue35935)
        # Remove this workaround once threading.Event.wait() is working with
        # CTRL+C across platforms.
        return _common.wait(self._state.termination_event.wait,
                            self._state.termination_event.is_set,
                            timeout=timeout)

    def stop(self, grace):
        return _stop(self._state, grace)

    def __del__(self):
        # __init__ may have failed before _state was assigned.
        if hasattr(self, '_state'):
            # We can not grab a lock in __del__(), so set a flag to signal the
            # serving daemon thread (if it exists) to initiate shutdown.
            self._state.server_deallocated = True
986
987
def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
                  maximum_concurrent_rpcs, compression):
    """Validate the generic RPC handlers and construct a ``_Server``.

    Args:
      thread_pool: Executor used to run method handlers.
      generic_rpc_handlers: Iterable of handlers; each must expose a
        ``service`` attribute.
      interceptors: Server interceptors, turned into a service pipeline.
      options: Channel options for the core server.
      maximum_concurrent_rpcs: Limit on concurrently running RPCs, or None.
      compression: Compression setting, appended to ``options``.

    Returns:
      A new, not-yet-started ``_Server``.

    Raises:
      AttributeError: if a handler lacks a ``service`` attribute.
    """
    # Validation iterates the handlers and _ServerState iterates them again.
    # Materialize once so a one-shot iterable (e.g. a generator) is not
    # exhausted by validation, which would leave the server with no handlers.
    generic_rpc_handlers = tuple(generic_rpc_handlers)
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(thread_pool, generic_rpc_handlers, interceptors, options,
                   maximum_concurrent_rpcs, compression)