Fix for x86_64 build fail
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_tokenizer / py / pw_tokenizer / detokenize.py
1 #!/usr/bin/env python3
2 # Copyright 2020 The Pigweed Authors
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
5 # use this file except in compliance with the License. You may obtain a copy of
6 # the License at
7 #
8 #     https://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 # License for the specific language governing permissions and limitations under
14 # the License.
15 r"""Decodes and detokenizes strings from binary or Base64 input.
16
17 The main class provided by this module is the Detokenize class. To use it,
18 construct it with the path to an ELF or CSV database, a tokens.Database,
19 or a file object for an ELF file or CSV. Then, call the detokenize method with
20 encoded messages, one at a time. The detokenize method returns a
21 DetokenizedString object with the result.
22
23 For example,
24
25   from pw_tokenizer import detokenize
26
27   detok = detokenize.Detokenizer('path/to/my/image.elf')
28   print(detok.detokenize(b'\x12\x34\x56\x78\x03hi!'))
29
30 This module also provides a command line interface for decoding and detokenizing
31 messages from a file or stdin.
32 """
33
34 import argparse
35 import base64
36 import binascii
37 from datetime import datetime
38 import io
39 import logging
40 import os
41 from pathlib import Path
42 import re
43 import string
44 import struct
45 import sys
46 import time
47 from typing import (BinaryIO, Callable, Dict, List, Iterable, Iterator, Match,
48                     NamedTuple, Optional, Pattern, Tuple, Union)
49
50 try:
51     from pw_tokenizer import database, decode, tokens
52 except ImportError:
53     # Append this path to the module search path to allow running this module
54     # without installing the pw_tokenizer package.
55     sys.path.append(os.path.dirname(os.path.dirname(
56         os.path.abspath(__file__))))
57     from pw_tokenizer import database, decode, tokens
58
59 ENCODED_TOKEN = struct.Struct('<I')
60 _LOG = logging.getLogger('pw_tokenizer')
61
62
63 class DetokenizedString:
64     """A detokenized string, with all results if there are collisions."""
65     def __init__(self,
66                  token: Optional[int],
67                  format_string_entries: Iterable[tuple],
68                  encoded_message: bytes,
69                  show_errors: bool = False):
70         self.token = token
71         self.encoded_message = encoded_message
72         self._show_errors = show_errors
73
74         self.successes: List[decode.FormattedString] = []
75         self.failures: List[decode.FormattedString] = []
76
77         decode_attempts: List[Tuple[Tuple, decode.FormattedString]] = []
78
79         for entry, fmt in format_string_entries:
80             result = fmt.format(encoded_message[ENCODED_TOKEN.size:],
81                                 show_errors)
82
83             # Sort competing entries so the most likely matches appear first.
84             # Decoded strings are prioritized by whether they
85             #
86             #   1. decoded all bytes for all arguments without errors,
87             #   2. decoded all data,
88             #   3. have the fewest decoding errors,
89             #   4. decoded the most arguments successfully, or
90             #   5. have the most recent removal date, if they were removed.
91             #
92             # This must match the collision resolution logic in detokenize.cc.
93             score: Tuple = (
94                 all(arg.ok() for arg in result.args) and not result.remaining,
95                 not result.remaining,  # decoded all data
96                 -sum(not arg.ok() for arg in result.args),  # fewest errors
97                 len(result.args),  # decoded the most arguments
98                 entry.date_removed or datetime.max)  # most recently present
99
100             decode_attempts.append((score, result))
101
102         # Sort the attempts by the score so the most likely results are first.
103         decode_attempts.sort(key=lambda value: value[0], reverse=True)
104
105         # Split out the successesful decodes from the failures.
106         for score, result in decode_attempts:
107             if score[0]:
108                 self.successes.append(result)
109             else:
110                 self.failures.append(result)
111
112     def ok(self) -> bool:
113         """True if exactly one string decoded the arguments successfully."""
114         return len(self.successes) == 1
115
116     def matches(self) -> List[decode.FormattedString]:
117         """Returns the strings that matched the token, best matches first."""
118         return self.successes + self.failures
119
120     def best_result(self) -> Optional[decode.FormattedString]:
121         """Returns the string and args for the most likely decoded string."""
122         for string_and_args in self.matches():
123             return string_and_args
124
125         return None
126
127     def error_message(self) -> str:
128         """If detokenization failed, returns a descriptive message."""
129         if self.ok():
130             return ''
131
132         if not self.matches():
133             if self.token is None:
134                 return 'missing token'
135
136             return 'unknown token {:08x}'.format(self.token)
137
138         if len(self.matches()) == 1:
139             return 'decoding failed for {!r}'.format(self.matches()[0].value)
140
141         return '{} matches'.format(len(self.matches()))
142
143     def __str__(self) -> str:
144         """Returns the string for the most likely result."""
145         result = self.best_result()
146         if result:
147             return result[0]
148
149         if self._show_errors:
150             return '<[ERROR: {}|{!r}]>'.format(self.error_message(),
151                                                self.encoded_message)
152         return ''
153
154     def __repr__(self) -> str:
155         if self.ok():
156             message = repr(str(self))
157         else:
158             message = 'ERROR: {}|{!r}'.format(self.error_message(),
159                                               self.encoded_message)
160
161         return '{}({})'.format(type(self).__name__, message)
162
163
164 class _TokenizedFormatString(NamedTuple):
165     entry: tokens.TokenizedStringEntry
166     format: decode.FormatString
167
168
169 class Detokenizer:
170     """Main detokenization class; detokenizes strings and caches results."""
171     def __init__(self, *token_database_or_elf, show_errors: bool = False):
172         """Decodes and detokenizes binary messages.
173
174         Args:
175           *token_database_or_elf: a path or file object for an ELF or CSV
176               database, a tokens.Database, or an elf_reader.Elf
177           show_errors: if True, an error message is used in place of the %
178               conversion specifier when an argument fails to decode
179         """
180         self.database = database.load_token_database(*token_database_or_elf)
181         self.show_errors = show_errors
182
183         # Cache FormatStrings for faster lookup & formatting.
184         self._cache: Dict[int, List[_TokenizedFormatString]] = {}
185
186     def lookup(self, token: int) -> List[_TokenizedFormatString]:
187         """Returns (TokenizedStringEntry, FormatString) list for matches."""
188         try:
189             return self._cache[token]
190         except KeyError:
191             format_strings = [
192                 _TokenizedFormatString(entry, decode.FormatString(str(entry)))
193                 for entry in self.database.token_to_entries[token]
194             ]
195             self._cache[token] = format_strings
196             return format_strings
197
198     def detokenize(self, encoded_message: bytes) -> DetokenizedString:
199         """Decodes and detokenizes a message as a DetokenizedString."""
200         if len(encoded_message) < ENCODED_TOKEN.size:
201             return DetokenizedString(None, (), encoded_message,
202                                      self.show_errors)
203
204         token, = ENCODED_TOKEN.unpack_from(encoded_message)
205         return DetokenizedString(token, self.lookup(token), encoded_message,
206                                  self.show_errors)
207
208
209 class AutoUpdatingDetokenizer:
210     """Loads and updates a detokenizer from database paths."""
211     class _DatabasePath:
212         """Tracks the modified time of a path or file object."""
213         def __init__(self, path):
214             self.path = path if isinstance(path, (str, Path)) else path.name
215             self._modified_time: Optional[float] = self._last_modified_time()
216
217         def updated(self) -> bool:
218             """True if the path has been updated since the last call."""
219             modified_time = self._last_modified_time()
220             if modified_time is None or modified_time == self._modified_time:
221                 return False
222
223             self._modified_time = modified_time
224             return True
225
226         def _last_modified_time(self) -> Optional[float]:
227             try:
228                 return os.path.getmtime(self.path)
229             except FileNotFoundError:
230                 return None
231
232         def load(self) -> tokens.Database:
233             try:
234                 return database.load_token_database(self.path)
235             except FileNotFoundError:
236                 return database.load_token_database()
237
238     def __init__(self,
239                  *paths_or_files,
240                  min_poll_period_s: float = 1.0) -> None:
241         self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
242         self.min_poll_period_s = min_poll_period_s
243         self._last_checked_time: float = time.time()
244         self._detokenizer = Detokenizer(*(path.load() for path in self.paths))
245
246     def detokenize(self, data: bytes) -> DetokenizedString:
247         """Updates the token database if it has changed, then detokenizes."""
248         if time.time() - self._last_checked_time >= self.min_poll_period_s:
249             self._last_checked_time = time.time()
250
251             if any(path.updated() for path in self.paths):
252                 _LOG.info('Changes detected; reloading token database')
253                 self._detokenizer = Detokenizer(*(path.load()
254                                                   for path in self.paths))
255
256         return self._detokenizer.detokenize(data)
257
258
259 _Detokenizer = Union[Detokenizer, AutoUpdatingDetokenizer]
260
261
262 class PrefixedMessageDecoder:
263     """Parses messages that start with a prefix character from a byte stream."""
264     def __init__(self, prefix: Union[str, bytes], chars: Union[str, bytes]):
265         """Parses prefixed messages.
266
267         Args:
268           prefix: one character that signifies the start of a message
269           chars: characters allowed in a message
270         """
271         self._prefix = prefix.encode() if isinstance(prefix, str) else prefix
272
273         if isinstance(chars, str):
274             chars = chars.encode()
275
276         # Store the valid message bytes as a set of binary strings.
277         self._message_bytes = frozenset(chars[i:i + 1]
278                                         for i in range(len(chars)))
279
280         if len(self._prefix) != 1 or self._prefix in self._message_bytes:
281             raise ValueError(
282                 'Invalid prefix {!r}: the prefix must be a single '
283                 'character that is not a valid message character.'.format(
284                     prefix))
285
286         self.data = bytearray()
287
288     def _read_next(self, fd: BinaryIO) -> Tuple[bytes, int]:
289         """Returns the next character and its index."""
290         char = fd.read(1)
291         index = len(self.data)
292         self.data += char
293         return char, index
294
295     def read_messages(self,
296                       binary_fd: BinaryIO) -> Iterator[Tuple[bool, bytes]]:
297         """Parses prefixed messages; yields (is_message, contents) chunks."""
298         message_start = None
299
300         while True:
301             # This reads the file character-by-character. Non-message characters
302             # are yielded right away; message characters are grouped.
303             char, index = self._read_next(binary_fd)
304
305             # If in a message, keep reading until the message completes.
306             if message_start is not None:
307                 if char in self._message_bytes:
308                     continue
309
310                 yield True, self.data[message_start:index]
311                 message_start = None
312
313             # Handle a non-message character.
314             if not char:
315                 return
316
317             if char == self._prefix:
318                 message_start = index
319             else:
320                 yield False, char
321
322     def transform(self, binary_fd: BinaryIO,
323                   transform: Callable[[bytes], bytes]) -> Iterator[bytes]:
324         """Yields the file with a transformation applied to the messages."""
325         for is_message, chunk in self.read_messages(binary_fd):
326             yield transform(chunk) if is_message else chunk
327
328
329 def _detokenize_prefixed_base64(
330         detokenizer: _Detokenizer, prefix: bytes,
331         recursion: int) -> Callable[[Match[bytes]], bytes]:
332     """Returns a function that decodes prefixed Base64 with the detokenizer."""
333     def decode_and_detokenize(match: Match[bytes]) -> bytes:
334         """Decodes prefixed base64 with the provided detokenizer."""
335         original = match.group(0)
336
337         try:
338             detokenized_string = detokenizer.detokenize(
339                 base64.b64decode(original[1:], validate=True))
340             if detokenized_string.matches():
341                 result = str(detokenized_string).encode()
342
343                 if recursion > 0 and original != result:
344                     result = detokenize_base64(detokenizer, result, prefix,
345                                                recursion - 1)
346
347                 return result
348         except binascii.Error:
349             pass
350
351         return original
352
353     return decode_and_detokenize
354
355
356 BASE64_PREFIX = b'$'
357 DEFAULT_RECURSION = 9
358
359
360 def _base64_message_regex(prefix: bytes) -> Pattern[bytes]:
361     """Returns a regular expression for prefixed base64 tokenized strings."""
362     return re.compile(
363         # Base64 tokenized strings start with the prefix character ($)
364         re.escape(prefix) + (
365             # Tokenized strings contain 0 or more blocks of four Base64 chars.
366             br'(?:[A-Za-z0-9+/\-_]{4})*'
367             # The last block of 4 chars may have one or two padding chars (=).
368             br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))
369
370
371 def detokenize_base64_live(detokenizer: _Detokenizer,
372                            input_file: BinaryIO,
373                            output: BinaryIO,
374                            prefix: Union[str, bytes] = BASE64_PREFIX,
375                            recursion: int = DEFAULT_RECURSION) -> None:
376     """Reads chars one-at-a-time and decodes messages; SLOW for big files."""
377     prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix
378
379     base64_message = _base64_message_regex(prefix_bytes)
380
381     def transform(data: bytes) -> bytes:
382         return base64_message.sub(
383             _detokenize_prefixed_base64(detokenizer, prefix_bytes, recursion),
384             data)
385
386     for message in PrefixedMessageDecoder(
387             prefix, string.ascii_letters + string.digits + '+/-_=').transform(
388                 input_file, transform):
389         output.write(message)
390
391         # Flush each line to prevent delays when piping between processes.
392         if b'\n' in message:
393             output.flush()
394
395
396 def detokenize_base64_to_file(detokenizer: _Detokenizer,
397                               data: bytes,
398                               output: BinaryIO,
399                               prefix: Union[str, bytes] = BASE64_PREFIX,
400                               recursion: int = DEFAULT_RECURSION) -> None:
401     """Decodes prefixed Base64 messages in data; decodes to an output file."""
402     prefix = prefix.encode() if isinstance(prefix, str) else prefix
403     output.write(
404         _base64_message_regex(prefix).sub(
405             _detokenize_prefixed_base64(detokenizer, prefix, recursion), data))
406
407
408 def detokenize_base64(detokenizer: _Detokenizer,
409                       data: bytes,
410                       prefix: Union[str, bytes] = BASE64_PREFIX,
411                       recursion: int = DEFAULT_RECURSION) -> bytes:
412     """Decodes and replaces prefixed Base64 messages in the provided data.
413
414     Args:
415       detokenizer: the detokenizer with which to decode messages
416       data: the binary data to decode
417       prefix: one-character byte string that signals the start of a message
418       recursion: how many levels to recursively decode
419
420     Returns:
421       copy of the data with all recognized tokens decoded
422     """
423     output = io.BytesIO()
424     detokenize_base64_to_file(detokenizer, data, output, prefix, recursion)
425     return output.getvalue()
426
427
428 def _follow_and_detokenize_file(detokenizer: _Detokenizer,
429                                 file: BinaryIO,
430                                 output: BinaryIO,
431                                 prefix: Union[str, bytes],
432                                 poll_period_s: float = 0.01) -> None:
433     """Polls a file to detokenize it and any appended data."""
434
435     try:
436         while True:
437             data = file.read()
438             if data:
439                 detokenize_base64_to_file(detokenizer, data, output, prefix)
440                 output.flush()
441             else:
442                 time.sleep(poll_period_s)
443     except KeyboardInterrupt:
444         pass
445
446
447 def _handle_base64(databases, input_file: BinaryIO, output: BinaryIO,
448                    prefix: str, show_errors: bool, follow: bool) -> None:
449     """Handles the base64 command line option."""
450     # argparse.FileType doesn't correctly handle - for binary files.
451     if input_file is sys.stdin:
452         input_file = sys.stdin.buffer
453
454     if output is sys.stdout:
455         output = sys.stdout.buffer
456
457     detokenizer = Detokenizer(tokens.Database.merged(*databases),
458                               show_errors=show_errors)
459
460     if follow:
461         _follow_and_detokenize_file(detokenizer, input_file, output, prefix)
462     elif input_file.seekable():
463         # Process seekable files all at once, which is MUCH faster.
464         detokenize_base64_to_file(detokenizer, input_file.read(), output,
465                                   prefix)
466     else:
467         # For non-seekable inputs (e.g. pipes), read one character at a time.
468         detokenize_base64_live(detokenizer, input_file, output, prefix)
469
470
471 def _parse_args() -> argparse.Namespace:
472     """Parses and return command line arguments."""
473
474     parser = argparse.ArgumentParser(
475         description=__doc__,
476         formatter_class=argparse.RawDescriptionHelpFormatter)
477     parser.set_defaults(handler=lambda **_: parser.print_help())
478
479     subparsers = parser.add_subparsers(help='Encoding of the input.')
480
481     base64_help = 'Detokenize Base64-encoded data from a file or stdin.'
482     subparser = subparsers.add_parser(
483         'base64',
484         description=base64_help,
485         parents=[database.token_databases_parser()],
486         help=base64_help)
487     subparser.set_defaults(handler=_handle_base64)
488     subparser.add_argument(
489         '-i',
490         '--input',
491         dest='input_file',
492         type=argparse.FileType('rb'),
493         default=sys.stdin.buffer,
494         help='The file from which to read; provide - or omit for stdin.')
495     subparser.add_argument(
496         '-f',
497         '--follow',
498         action='store_true',
499         help=('Detokenize data appended to input_file as it grows; similar to '
500               'tail -f.'))
501     subparser.add_argument('-o',
502                            '--output',
503                            type=argparse.FileType('wb'),
504                            default=sys.stdout.buffer,
505                            help=('The file to which to write the output; '
506                                  'provide - or omit for stdout.'))
507     subparser.add_argument(
508         '-p',
509         '--prefix',
510         default=BASE64_PREFIX,
511         help=('The one-character prefix that signals the start of a '
512               'Base64-encoded message. (default: $)'))
513     subparser.add_argument(
514         '-s',
515         '--show_errors',
516         action='store_true',
517         help=('Show error messages instead of conversion specifiers when '
518               'arguments cannot be decoded.'))
519
520     return parser.parse_args()
521
522
523 def main() -> int:
524     args = _parse_args()
525
526     handler = args.handler
527     del args.handler
528
529     handler(**vars(args))
530     return 0
531
532
533 if __name__ == '__main__':
534     if sys.version_info[0] < 3:
535         sys.exit('ERROR: The detokenizer command line tools require Python 3.')
536     sys.exit(main())