#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the tokens module."""

import datetime
import io
import logging
from pathlib import Path
import tempfile
from typing import Iterator
import unittest

from pw_tokenizer import tokens
from pw_tokenizer.tokens import default_hash, _LOG

CSV_DATABASE = '''\
00000000,2019-06-10,""
141c35d5,          ,"The answer: ""%s"""
2db1515f,          ,"%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c"
2e668cd6,2019-06-11,"Jello, world!"
31631781,          ,"%d"
61fd1e26,          ,"%ld"
68ab92da,          ,"%s there are %x (%.2f) of them%c"
7b940e2a,          ,"Hello %s! %hd %e"
851beeb6,          ,"%u %d"
881436a0,          ,"The answer is: %s"
ad002c97,          ,"%llx"
b3653e13,2019-06-12,"Jello!"
b912567b,          ,"%x%lld%1.2f%s"
cc6d3131,2020-01-01,"Jello?"
e13b0f94,          ,"%llu"
e65aefef,2019-06-10,"Won't fit : %s%d"
'''
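# Each row above is <8-digit hex token>,<removal date or blank>,"<string>";
# the CSV round-trip tests below expect exactly this formatting.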

# The date 2019-06-10 is 07E3-06-0A in hex. In database order, it's 0A 06 E3 07.
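# A sketch of that packing (assuming day, month, little-endian year order):
#
#   import struct
#   assert struct.pack('<BBH', 10, 6, 2019) == b'\x0a\x06\xe3\x07'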
BINARY_DATABASE = (
    b'TOKENS\x00\x00\x10\x00\x00\x00\0\0\0\0'  # header (0x10 entries)
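    # Header above: the 6-byte b'TOKENS' magic, two bytes (apparently a
    # format version), a little-endian uint32 entry count (0x10), and four
    # padding bytes.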
    b'\x00\x00\x00\x00\x0a\x06\xe3\x07'  # 0x01
    b'\xd5\x35\x1c\x14\xff\xff\xff\xff'  # 0x02
    b'\x5f\x51\xb1\x2d\xff\xff\xff\xff'  # 0x03
    b'\xd6\x8c\x66\x2e\x0b\x06\xe3\x07'  # 0x04
    b'\x81\x17\x63\x31\xff\xff\xff\xff'  # 0x05
    b'\x26\x1e\xfd\x61\xff\xff\xff\xff'  # 0x06
    b'\xda\x92\xab\x68\xff\xff\xff\xff'  # 0x07
    b'\x2a\x0e\x94\x7b\xff\xff\xff\xff'  # 0x08
    b'\xb6\xee\x1b\x85\xff\xff\xff\xff'  # 0x09
    b'\xa0\x36\x14\x88\xff\xff\xff\xff'  # 0x0a
    b'\x97\x2c\x00\xad\xff\xff\xff\xff'  # 0x0b
    b'\x13\x3e\x65\xb3\x0c\x06\xe3\x07'  # 0x0c
    b'\x7b\x56\x12\xb9\xff\xff\xff\xff'  # 0x0d
    b'\x31\x31\x6d\xcc\x01\x01\xe4\x07'  # 0x0e
    b'\x94\x0f\x3b\xe1\xff\xff\xff\xff'  # 0x0f
    b'\xef\xef\x5a\xe6\x0a\x06\xe3\x07'  # 0x10
    b'\x00'
    b'The answer: "%s"\x00'
    b'%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c\x00'
    b'Jello, world!\x00'
    b'%d\x00'
    b'%ld\x00'
    b'%s there are %x (%.2f) of them%c\x00'
    b'Hello %s! %hd %e\x00'
    b'%u %d\x00'
    b'The answer is: %s\x00'
    b'%llx\x00'
    b'Jello!\x00'
    b'%x%lld%1.2f%s\x00'
    b'Jello?\x00'
    b'%llu\x00'
    b'Won\'t fit : %s%d\x00')

INVALID_CSV = """\
1,,"Whoa there!"
2,this is totally invalid,"Whoa there!"
3,,"This one's OK"
,,"Also broken"
5,1845-2-2,"I'm %s fine"
6,"Missing fields"
"""


def read_db_from_csv(csv_str: str) -> tokens.Database:
    with io.StringIO(csv_str) as csv_db:
        return tokens.Database(tokens.parse_csv(csv_db))


def _entries(*strings: str) -> Iterator[tokens.TokenizedStringEntry]:
    for string in strings:
        yield tokens.TokenizedStringEntry(default_hash(string), string)
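

# default_hash, used by _entries() above, is intended to mirror the hash the
# C/C++ tokenizer computes at compile time, so these entries carry realistic
# tokens.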


class TokenDatabaseTest(unittest.TestCase):
    """Tests the token database class."""
    def test_csv(self):
        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(str(db), CSV_DATABASE)

        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

    def test_csv_formatting(self):
        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

        db = read_db_from_csv('abc123,2048-4-1,Fake string\n')
        self.assertEqual(str(db), '00abc123,2048-04-01,"Fake string"\n')

        db = read_db_from_csv('1,1990-01-01,"Quotes"""\n'
                              '0,1990-02-01,"Commas,"",,"\n')
        self.assertEqual(str(db), ('00000000,1990-02-01,"Commas,"",,"\n'
                                   '00000001,1990-01-01,"Quotes"""\n'))

    def test_bad_csv(self):
        with self.assertLogs(_LOG, logging.ERROR) as logs:
            db = read_db_from_csv(INVALID_CSV)

        self.assertGreaterEqual(len(logs.output), 3)
        self.assertEqual(len(db.token_to_entries), 3)

        self.assertEqual(db.token_to_entries[1][0].string, 'Whoa there!')
        self.assertFalse(db.token_to_entries[2])
        self.assertEqual(db.token_to_entries[3][0].string, "This one's OK")
        self.assertFalse(db.token_to_entries[4])
        self.assertEqual(db.token_to_entries[5][0].string, "I'm %s fine")
        self.assertFalse(db.token_to_entries[6])

    def test_lookup(self):
        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(db.token_to_entries[0x9999], [])

        matches = db.token_to_entries[0x2e668cd6]
        self.assertEqual(len(matches), 1)
        jello = matches[0]

        self.assertEqual(jello.token, 0x2e668cd6)
        self.assertEqual(jello.string, 'Jello, world!')
        self.assertEqual(jello.date_removed, datetime.datetime(2019, 6, 11))

        matches = db.token_to_entries[0xe13b0f94]
        self.assertEqual(len(matches), 1)
        llu = matches[0]
        self.assertEqual(llu.token, 0xe13b0f94)
        self.assertEqual(llu.string, '%llu')
        self.assertIsNone(llu.date_removed)

        answer, = db.token_to_entries[0x141c35d5]
        self.assertEqual(answer.string, 'The answer: "%s"')

    def test_collisions(self):
        hash_1 = tokens.pw_tokenizer_65599_fixed_length_hash('o000', 96)
        hash_2 = tokens.pw_tokenizer_65599_fixed_length_hash('0Q1Q', 96)
        self.assertEqual(hash_1, hash_2)
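
        # Both strings hash to the same value under the fixed-length 65599
        # hash, so the database must keep two entries behind a single token.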

        db = tokens.Database.from_strings(['o000', '0Q1Q'])

        self.assertEqual(len(db.token_to_entries[hash_1]), 2)
        self.assertCountEqual(
            [entry.string for entry in db.token_to_entries[hash_1]],
            ['o000', '0Q1Q'])

    def test_purge(self):
        db = read_db_from_csv(CSV_DATABASE)
        original_length = len(db.token_to_entries)

        self.assertEqual(db.token_to_entries[0][0].string, '')
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertEqual(db.token_to_entries[0x2e668cd6][0].string,
                         'Jello, world!')
        self.assertEqual(db.token_to_entries[0xb3653e13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xcc6d3131][0].string, 'Jello?')
        self.assertEqual(db.token_to_entries[0xe65aefef][0].string,
                         "Won't fit : %s%d")

        db.purge(datetime.datetime(2019, 6, 11))
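        # Entries removed on or before 2019-06-11 should be dropped; later
        # removals and never-removed entries should survive.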
        self.assertLess(len(db.token_to_entries), original_length)

        self.assertFalse(db.token_to_entries[0])
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertFalse(db.token_to_entries[0x2e668cd6])
        self.assertEqual(db.token_to_entries[0xb3653e13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xcc6d3131][0].string, 'Jello?')
        self.assertFalse(db.token_to_entries[0xe65aefef])

    def test_merge(self):
        """Tests the tokens.Database merge method."""

        db = tokens.Database()

        # Test basic merging into an empty database.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    1, 'one', date_removed=datetime.datetime.min),
                tokens.TokenizedStringEntry(
                    2, 'two', date_removed=datetime.datetime.min),
            ]))
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})
        self.assertEqual(db.token_to_entries[1][0].date_removed,
                         datetime.datetime.min)
        self.assertEqual(db.token_to_entries[2][0].date_removed,
                         datetime.datetime.min)

        # Test merging in an entry with a removal date.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(3, 'three'),
                tokens.TokenizedStringEntry(
                    4, 'four', date_removed=datetime.datetime.min),
            ]))
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four'})
        self.assertIsNone(db.token_to_entries[3][0].date_removed)
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.min)

        # Test merging in one entry.
        db.merge(tokens.Database([
            tokens.TokenizedStringEntry(5, 'five'),
        ]))
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.min)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in repeated entries with different removal dates.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    4, 'four', date_removed=datetime.datetime.max),
                tokens.TokenizedStringEntry(
                    5, 'five', date_removed=datetime.datetime.max),
            ]))
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.max)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in the same repeated entries now without removal dates.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(4, 'four'),
                tokens.TokenizedStringEntry(5, 'five')
            ]))
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertIsNone(db.token_to_entries[4][0].date_removed)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in an empty database.
        db.merge(tokens.Database([]))
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})

    def test_merge_multiple_databases_in_one_call(self):
        """Tests the merge and merged methods with multiple databases."""
        db = tokens.Database.merged(
            tokens.Database([
                tokens.TokenizedStringEntry(1,
                                            'one',
                                            date_removed=datetime.datetime.max)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(2,
                                            'two',
                                            date_removed=datetime.datetime.min)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(1,
                                            'one',
                                            date_removed=datetime.datetime.min)
            ]))
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})

        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(4,
                                            'four',
                                            date_removed=datetime.datetime.max)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(2,
                                            'two',
                                            date_removed=datetime.datetime.max)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(3,
                                            'three',
                                            date_removed=datetime.datetime.min)
            ]))
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four'})

    def test_entry_counts(self):
        self.assertEqual(len(CSV_DATABASE.splitlines()), 16)

        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(len(db.entries()), 16)
        self.assertEqual(len(db.token_to_entries), 16)

        # Add two strings with the same hash.
        db.add(_entries('o000', '0Q1Q'))

        self.assertEqual(len(db.entries()), 18)
        self.assertEqual(len(db.token_to_entries), 17)

    def test_mark_removals(self):
        """Tests that the date_removed field is set by mark_removals."""
        db = tokens.Database.from_strings(
            ['MILK', 'apples', 'oranges', 'CHEESE', 'pears'])

        self.assertTrue(
            all(entry.date_removed is None for entry in db.entries()))
        date_1 = datetime.datetime(1, 2, 3)

        db.mark_removals(_entries('apples', 'oranges', 'pears'), date_1)
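        # Strings absent from the set passed to mark_removals() (here MILK
        # and CHEESE) should be marked as removed on the given date.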

        self.assertEqual(
            db.token_to_entries[default_hash('MILK')][0].date_removed, date_1)
        self.assertEqual(
            db.token_to_entries[default_hash('CHEESE')][0].date_removed,
            date_1)

        now = datetime.datetime.now()
        db.mark_removals(_entries('MILK', 'CHEESE', 'pears'))

        # mark_removals() does not add new strings or clear the removal dates
        # of entries that are still present.
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('MILK')][0].date_removed, date_1)
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('CHEESE')][0].date_removed,
            date_1)

        # These strings were removed.
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('apples')][0].date_removed, now)
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('oranges')][0].date_removed, now)
        self.assertIsNone(
            db.token_to_entries[default_hash('pears')][0].date_removed)

    def test_add(self):
        db = tokens.Database()
        db.add(_entries('MILK', 'apples'))
        self.assertEqual({e.string for e in db.entries()}, {'MILK', 'apples'})

        db.add(_entries('oranges', 'CHEESE', 'pears'))
        self.assertEqual(len(db.entries()), 5)

        db.add(_entries('MILK', 'apples', 'only this one is new'))
        self.assertEqual(len(db.entries()), 6)

        db.add(_entries('MILK'))
        self.assertEqual({e.string
                          for e in db.entries()}, {
                              'MILK', 'apples', 'oranges', 'CHEESE', 'pears',
                              'only this one is new'
                          })

    def test_binary_format_write(self):
        db = read_db_from_csv(CSV_DATABASE)

        with io.BytesIO() as fd:
            tokens.write_binary(db, fd)
            binary_db = fd.getvalue()

        self.assertEqual(BINARY_DATABASE, binary_db)

    def test_binary_format_parse(self):
        with io.BytesIO(BINARY_DATABASE) as binary_db:
            db = tokens.Database(tokens.parse_binary(binary_db))

        self.assertEqual(str(db), CSV_DATABASE)


class TestDatabaseFile(unittest.TestCase):
    """Tests the DatabaseFile class."""
    def setUp(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        file.close()
        self._path = Path(file.name)

    def tearDown(self):
        self._path.unlink()

    def test_update_csv_file(self):
        self._path.write_text(CSV_DATABASE)
        db = tokens.DatabaseFile(self._path)
        self.assertEqual(str(db), CSV_DATABASE)

        db.add([tokens.TokenizedStringEntry(0xffffffff, 'New entry!')])

        db.write_to_file()

        self.assertEqual(self._path.read_text(),
                         CSV_DATABASE + 'ffffffff,          ,"New entry!"\n')

    def test_csv_file_too_short_raises_exception(self):
        self._path.write_text('1234')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)

    def test_csv_invalid_format_raises_exception(self):
        self._path.write_text('MK34567890')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)

    def test_csv_not_utf8(self):
        self._path.write_bytes(b'\x80' * 20)

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)


class TestFilter(unittest.TestCase):
    """Tests the filtering functionality."""
    def setUp(self):
        self.db = tokens.Database([
            tokens.TokenizedStringEntry(1, 'Luke'),
            tokens.TokenizedStringEntry(2, 'Leia'),
            tokens.TokenizedStringEntry(2, 'Darth Vader'),
            tokens.TokenizedStringEntry(2, 'Emperor Palpatine'),
            tokens.TokenizedStringEntry(3, 'Han'),
            tokens.TokenizedStringEntry(4, 'Chewbacca'),
            tokens.TokenizedStringEntry(5, 'Darth Maul'),
            tokens.TokenizedStringEntry(6, 'Han Solo'),
        ])
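        # Token 2 maps to three different strings above, so the filter tests
        # also exercise colliding entries.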

    def test_filter_include_single_regex(self):
        self.db.filter(include=[' '])  # anything with a space
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Darth Vader', 'Emperor Palpatine', 'Darth Maul', 'Han Solo'})

    def test_filter_include_multiple_regexes(self):
        self.db.filter(include=['Darth', 'cc', '^Han$'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Darth Vader', 'Darth Maul', 'Han', 'Chewbacca'})

    def test_filter_include_no_matches(self):
        self.db.filter(include=['Gandalf'])
        self.assertFalse(self.db.entries())

    def test_filter_exclude_single_regex(self):
        self.db.filter(exclude=['^[^L]'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Luke', 'Leia'})

    def test_filter_exclude_multiple_regexes(self):
        self.db.filter(exclude=[' ', 'Han', 'Chewbacca'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Luke', 'Leia'})

    def test_filter_exclude_no_matches(self):
        self.db.filter(exclude=['.*'])
        self.assertFalse(self.db.entries())

    def test_filter_include_and_exclude(self):
        self.db.filter(include=[' '], exclude=['Darth', 'Emperor'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Han Solo'})

    def test_filter_neither_include_nor_exclude(self):
        self.db.filter()
        self.assertEqual(
            set(e.string for e in self.db.entries()), {
                'Luke', 'Leia', 'Darth Vader', 'Emperor Palpatine', 'Han',
                'Chewbacca', 'Darth Maul', 'Han Solo'
            })


if __name__ == '__main__':
    unittest.main()