Fix for x86_64 build failure
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_hdlc / py / pw_hdlc / rpc.py
1 # Copyright 2020 The Pigweed Authors
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
5 # the License at
6 #
7 #     https://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
13 # the License.
14 """Utilities for using HDLC with pw_rpc."""
15
16 from concurrent.futures import ThreadPoolExecutor
17 import logging
18 import sys
19 import threading
20 import time
21 from typing import (Any, BinaryIO, Callable, Dict, Iterable, List, NoReturn,
22                     Optional, Union)
23
24 from pw_protobuf_compiler import python_protos
25 import pw_rpc
26 from pw_rpc import callback_client
27
28 from pw_hdlc.decode import Frame, FrameDecoder
29 from pw_hdlc import encode
30
_LOG: logging.Logger = logging.getLogger(__name__)

# HDLC address of frames whose payload is "stdout" output from the device.
STDOUT_ADDRESS: int = 1
# HDLC address used for pw_rpc packet frames, both received and (by default)
# sent. The value is the ASCII code of 'R' (0x52).
DEFAULT_ADDRESS: int = ord('R')
35
36
def channel_output(writer: Callable[[bytes], Any],
                   address: int = DEFAULT_ADDRESS,
                   delay_s: float = 0) -> Callable[[bytes], None]:
    """Builds a pw_rpc channel output that sends packets as HDLC UI frames.

    Args:
      writer: called with the encoded bytes to transmit
      address: HDLC address encoded into every frame
      delay_s: if nonzero, the encoded frame is written one byte at a time,
          sleeping this many seconds before each byte (for unbuffered
          serial); writes on this slow path are not logged

    Returns:
      A function that takes a packet, HDLC-encodes it and passes it to writer.
    """
    def write_hdlc(data: bytes):
        frame = encode.ui_frame(address, data)
        _LOG.debug('Write %2d B: %s', len(frame), frame)
        writer(frame)

    def write_hdlc_slowly(data: bytes) -> None:
        """Slows down writes in case unbuffered serial is in use."""
        for value in encode.ui_frame(address, data):
            time.sleep(delay_s)
            writer(bytes((value,)))

    return write_hdlc_slowly if delay_s else write_hdlc
58
59
def _handle_error(frame: Frame) -> None:
    """Default handler for frames whose ok() check fails.

    Logs the frame's status value as an error and, at debug level only, the
    frame's raw data.
    """
    status = frame.status.value
    _LOG.error('Failed to parse frame: %s', status)
    payload = frame.data
    _LOG.debug('%s', payload)


# Maps an HDLC frame address to the function that handles frames received on
# that address (see read_and_process_data). Handler return values are ignored.
FrameHandlers = Dict[int, Callable[[Frame], Any]]
66
67
def read_and_process_data(read: Callable[[], bytes],
                          on_read_error: Callable[[Exception], Any],
                          frame_handlers: FrameHandlers,
                          error_handler: Callable[[Frame],
                                                  Any] = _handle_error,
                          handler_threads: Optional[int] = 1) -> NoReturn:
    """Continuously reads and handles HDLC frames.

    Passes frames to an executor that calls frame handler functions in other
    threads. This function loops forever; it only exits by raising.

    Args:
      read: Function returning the next chunk of bytes; empty results are
          skipped.
      on_read_error: Called with the exception whenever read() raises;
          reading then continues. Anything it raises propagates out of this
          function.
      frame_handlers: Maps each HDLC address to the handler for frames
          received on that address.
      error_handler: Called for frames whose ok() check fails.
      handler_threads: max_workers for the ThreadPoolExecutor that runs the
          handlers (None selects the executor's default).
    """
    def handle_frame(frame: Frame):
        """Dispatches one frame; exceptions are logged, not raised."""
        try:
            if not frame.ok():
                error_handler(frame)
                return

            # Only the lookup belongs in this try. A KeyError raised inside
            # the handler itself must not be misreported as an unhandled
            # address; it is logged as a handler exception below instead.
            try:
                handler = frame_handlers[frame.address]
            except KeyError:
                _LOG.warning('Unhandled frame for address %d: %s',
                             frame.address, frame)
                return

            handler(frame)
        except:  # pylint: disable=bare-except
            # The Future returned by submit() is discarded, so log here or
            # the exception would never be seen.
            _LOG.exception('Exception in HDLC frame handler thread')

    decoder = FrameDecoder()

    # Execute callbacks in a ThreadPoolExecutor to decouple reading the input
    # stream from handling the data. That way, if a handler function takes a
    # long time or crashes, this reading thread is not interrupted.
    with ThreadPoolExecutor(max_workers=handler_threads) as executor:
        while True:
            try:
                data = read()
            except Exception as exc:  # pylint: disable=broad-except
                on_read_error(exc)
                continue

            if data:
                _LOG.debug('Read %2d B: %s', len(data), data)

                for frame in decoder.process_valid_frames(data):
                    executor.submit(handle_frame, frame)
109                 for frame in decoder.process_valid_frames(data):
110                     executor.submit(handle_frame, frame)
111
112
def write_to_file(data: bytes, output: Optional[BinaryIO] = None):
    """Writes data followed by a newline to a binary stream, then flushes.

    Args:
      data: bytes to write; a trailing newline byte is appended.
      output: binary stream to write to. Defaults to sys.stdout.buffer.
          The default is looked up on each call rather than at import time.
          Importing this module then does not fail when sys.stdout lacks a
          .buffer, and later redirections of sys.stdout are honored.
    """
    if output is None:
        output = sys.stdout.buffer
    output.write(data + b'\n')
    output.flush()
116
117
def default_channels(write: Callable[[bytes], Any]) -> List[pw_rpc.Channel]:
    """Builds the default channel list: a single pw_rpc channel with ID 1.

    Packets sent on that channel are HDLC-encoded by channel_output() with
    its default arguments and then passed to write.
    """
    hdlc_output = channel_output(write)
    return [pw_rpc.Channel(1, hdlc_output)]
120
121
class HdlcRpcClient:
    """An RPC client configured to run over HDLC."""
    def __init__(self,
                 read: Callable[[], bytes],
                 paths_or_modules: Union[Iterable[python_protos.PathOrModule],
                                         python_protos.Library],
                 channels: Iterable[pw_rpc.Channel],
                 output: Callable[[bytes], Any] = write_to_file,
                 client_impl: Optional[pw_rpc.client.ClientImpl] = None):
        """Creates an RPC client configured to communicate using HDLC.

        Starts a daemon thread that reads from `read` and dispatches frames.

        Args:
          read: Function that reads bytes; e.g serial_device.read.
          paths_or_modules: paths to .proto files or proto modules, or an
              already-built python_protos.Library
          channels: RPC channels to use for output
          output: where to write "stdout" output from the device (the data
              of frames received on STDOUT_ADDRESS)
          client_impl: pw_rpc client implementation; defaults to
              callback_client.Impl()
        """
        if isinstance(paths_or_modules, python_protos.Library):
            self.protos = paths_or_modules
        else:
            self.protos = python_protos.Library.from_paths(paths_or_modules)

        if client_impl is None:
            client_impl = callback_client.Impl()

        self.client = pw_rpc.Client.from_modules(client_impl, channels,
                                                 self.protos.modules())
        frame_handlers: FrameHandlers = {
            DEFAULT_ADDRESS: self._handle_rpc_packet,
            STDOUT_ADDRESS: lambda frame: output(frame.data),
        }

        # Start background thread that reads and processes RPC packets.
        # read_and_process_data calls on_read_error(exc) with the exception.
        # The previous zero-argument `lambda: None` raised TypeError on the
        # first read error and killed this thread. Read errors are
        # deliberately ignored so that reading continues.
        threading.Thread(target=read_and_process_data,
                         daemon=True,
                         args=(read, lambda _exc: None,
                               frame_handlers)).start()

    def rpcs(self, channel_id: Optional[int] = None) -> Any:
        """Returns object for accessing services on the specified channel.

        This skips some intermediate layers to make it simpler to invoke RPCs
        from an HdlcRpcClient. If only one channel is in use, the channel ID is
        not necessary. When channel_id is None, the first channel returned by
        the client is used.
        """
        if channel_id is None:
            return next(iter(self.client.channels())).rpcs

        return self.client.channel(channel_id).rpcs

    def _handle_rpc_packet(self, frame: Frame) -> None:
        """Feeds an RPC frame's data to the client; logs it if unhandled."""
        if not self.client.process_packet(frame.data):
            _LOG.error('Packet not handled by RPC client: %s', frame.data)
172         if not self.client.process_packet(frame.data):
173             _LOG.error('Packet not handled by RPC client: %s', frame.data)