darray: check integer overflow
[platform/upstream/libxkbcommon.git] / scripts / perfect_hash.py
1 # Derived from: https://github.com/ilanschnell/perfect-hash
2 # Commit: 6b7dd80a525dbd4349ea2c69f04a9c96f3c2fd54
3
4 # BSD 3-Clause License
5 #
6 # Copyright (c) 2019 - 2021, Ilan Schnell
7 # All rights reserved.
8 #
9 # Redistribution and use in source and binary forms, with or without
10 # modification, are permitted provided that the following conditions are met:
11 #     * Redistributions of source code must retain the above copyright
12 #       notice, this list of conditions and the following disclaimer.
13 #     * Redistributions in binary form must reproduce the above copyright
14 #       notice, this list of conditions and the following disclaimer in the
15 #       documentation and/or other materials provided with the distribution.
16 #     * Neither the name of the Ilan Schnell nor the
17 #       names of its contributors may be used to endorse or promote products
18 #       derived from this software without specific prior written permission.
19 #
20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
21 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
22 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 # DISCLAIMED. IN NO EVENT SHALL ILAN SCHNELL BE LIABLE FOR ANY
24 # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
25 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
26 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
27 # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 """
32 Generate a minimal perfect hash function for the keys in a file,
33 desired hash values may be specified within this file as well.
34 A given code template is filled with parameters, such that the
35 output is code which implements the hash function.
36 Templates can easily be constructed for any programming language.
37
38 The code is based on an a program A.M. Kuchling wrote:
39 http://www.amk.ca/python/code/perfect-hash
40
41 The algorithm the program uses is described in the paper
42 'Optimal algorithms for minimal perfect hashing',
43 Z. J. Czech, G. Havas and B.S. Majewski.
44 http://citeseer.ist.psu.edu/122364.html
45
46 The algorithm works like this:
47
48 1.  You have K keys, that you want to perfectly hash against some
49     desired hash values.
50
51 2.  Choose a number N larger than K.  This is the number of
52     vertices in a graph G, and also the size of the resulting table G.
53
54 3.  Pick two random hash functions f1, f2, that return values from 0..N-1.
55
56 4.  Now, for all keys, you draw an edge between vertices f1(key) and f2(key)
57     of the graph G, and associate the desired hash value with that edge.
58
59 5.  If G is cyclic, go back to step 2.
60
61 6.  Assign values to each vertex such that, for each edge, you can add
62     the values for the two vertices and get the desired (hash) value
63     for that edge.  This task is easy, because the graph is acyclic.
64     This is done by picking a vertex, and assigning it a value of 0.
65     Then do a depth-first search, assigning values to new vertices so that
66     they sum up properly.
67
68 7.  f1, f2, and vertex values of G now make up a perfect hash function.
69
70
71 For simplicity, the implementation of the algorithm combines steps 5 and 6.
72 That is, we check for loops in G and assign the vertex values in one procedure.
73 If this procedure succeeds, G is acyclic and the vertex values are assigned.
74 If the procedure fails, G is cyclic, and we go back to step 2, replacing G
75 with a new graph, and thereby discarding the vertex values from the failed
76 attempt.
77 """
78 from __future__ import absolute_import, division, print_function
79
80 import sys
81 import random
82 import string
83 import subprocess
84 import shutil
85 import tempfile
86 from collections import defaultdict
87 from os.path import join
88
89 if sys.version_info[0] == 2:
90     from cStringIO import StringIO
91 else:
92     from io import StringIO
93
94
95 __version__ = "0.4.2"
96
97
98 verbose = False
99 trials = 150
100
101
102 class Graph(object):
103     """
104     Implements a graph with 'N' vertices.  First, you connect the graph with
105     edges, which have a desired value associated.  Then the vertex values
106     are assigned, which will fail if the graph is cyclic.  The vertex values
107     are assigned such that the two values corresponding to an edge add up to
108     the desired edge value (mod N).
109     """
110
111     def __init__(self, N):
112         self.N = N  # number of vertices
113
114         # maps a vertex number to the list of tuples (vertex, edge value)
115         # to which it is connected by edges.
116         self.adjacent = defaultdict(list)
117
118     def connect(self, vertex1, vertex2, edge_value):
119         """
120         Connect 'vertex1' and 'vertex2' with an edge, with associated
121         value 'value'
122         """
123         # Add vertices to each other's adjacent list
124         self.adjacent[vertex1].append((vertex2, edge_value))
125         self.adjacent[vertex2].append((vertex1, edge_value))
126
127     def assign_vertex_values(self):
128         """
129         Try to assign the vertex values, such that, for each edge, you can
130         add the values for the two vertices involved and get the desired
131         value for that edge, i.e. the desired hash key.
132         This will fail when the graph is cyclic.
133
134         This is done by a Depth-First Search of the graph.  If the search
135         finds a vertex that was visited before, there's a loop and False is
136         returned immediately, i.e. the assignment is terminated.
137         On success (when the graph is acyclic) True is returned.
138         """
139         self.vertex_values = self.N * [-1]  # -1 means unassigned
140
141         visited = self.N * [False]
142
143         # Loop over all vertices, taking unvisited ones as roots.
144         for root in range(self.N):
145             if visited[root]:
146                 continue
147
148             # explore tree starting at 'root'
149             self.vertex_values[root] = 0  # set arbitrarily to zero
150
151             # Stack of vertices to visit, a list of tuples (parent, vertex)
152             tovisit = [(None, root)]
153             while tovisit:
154                 parent, vertex = tovisit.pop()
155                 visited[vertex] = True
156
157                 # Loop over adjacent vertices, but skip the vertex we arrived
158                 # here from the first time it is encountered.
159                 skip = True
160                 for neighbor, edge_value in self.adjacent[vertex]:
161                     if skip and neighbor == parent:
162                         skip = False
163                         continue
164
165                     if visited[neighbor]:
166                         # We visited here before, so the graph is cyclic.
167                         return False
168
169                     tovisit.append((vertex, neighbor))
170
171                     # Set new vertex's value to the desired edge value,
172                     # minus the value of the vertex we came here from.
173                     self.vertex_values[neighbor] = (
174                         edge_value - self.vertex_values[vertex]
175                     ) % self.N
176
177         # check if all vertices have a valid value
178         for vertex in range(self.N):
179             assert self.vertex_values[vertex] >= 0
180
181         # We got though, so the graph is acyclic,
182         # and all values are now assigned.
183         return True
184
185
186 class StrSaltHash(object):
187     """
188     Random hash function generator.
189     Simple byte level hashing: each byte is multiplied to another byte from
190     a random string of characters, summed up, and finally modulo NG is
191     taken.
192     """
193
194     chars = string.ascii_letters + string.digits
195
196     def __init__(self, N):
197         self.N = N
198         self.salt = ""
199
200     def __call__(self, key):
201         # XXX: xkbcommon modification: make the salt length a power of 2
202         #      so that the % operation in the hash is fast.
203         while len(self.salt) < max(len(key), 32):  # add more salt as necessary
204             self.salt += random.choice(self.chars)
205
206         return sum(ord(self.salt[i]) * ord(c) for i, c in enumerate(key)) % self.N
207
208     template = """
209 def hash_f(key, T):
210     return sum(ord(T[i % $NS]) * ord(c) for i, c in enumerate(key)) % $NG
211
212 def perfect_hash(key):
213     return (G[hash_f(key, "$S1")] +
214             G[hash_f(key, "$S2")]) % $NG
215 """
216
217
218 class IntSaltHash(object):
219     """
220     Random hash function generator.
221     Simple byte level hashing, each byte is multiplied in sequence to a table
222     containing random numbers, summed tp, and finally modulo NG is taken.
223     """
224
225     def __init__(self, N):
226         self.N = N
227         self.salt = []
228
229     def __call__(self, key):
230         while len(self.salt) < len(key):  # add more salt as necessary
231             self.salt.append(random.randint(1, self.N - 1))
232
233         return sum(self.salt[i] * ord(c) for i, c in enumerate(key)) % self.N
234
235     template = """
236 S1 = [$S1]
237 S2 = [$S2]
238 assert len(S1) == len(S2) == $NS
239
240 def hash_f(key, T):
241     return sum(T[i % $NS] * ord(c) for i, c in enumerate(key)) % $NG
242
243 def perfect_hash(key):
244     return (G[hash_f(key, S1)] + G[hash_f(key, S2)]) % $NG
245 """
246
247
248 def builtin_template(Hash):
249     return (
250         """\
251 # =======================================================================
252 # ================= Python code for perfect hash function ===============
253 # =======================================================================
254
255 G = [$G]
256 """
257         + Hash.template
258         + """
259 # ============================ Sanity check =============================
260
261 K = [$K]
262 assert len(K) == $NK
263
264 for h, k in enumerate(K):
265     assert perfect_hash(k) == h
266 """
267     )
268
269
270 class TooManyInterationsError(Exception):
271     pass
272
273
274 def generate_hash(keys, Hash=StrSaltHash):
275     """
276     Return hash functions f1 and f2, and G for a perfect minimal hash.
277     Input is an iterable of 'keys', whos indicies are the desired hash values.
278     'Hash' is a random hash function generator, that means Hash(N) returns a
279     returns a random hash function which returns hash values from 0..N-1.
280     """
281     if not isinstance(keys, (list, tuple)):
282         raise TypeError("list or tuple expected")
283     NK = len(keys)
284     if NK != len(set(keys)):
285         raise ValueError("duplicate keys")
286     for key in keys:
287         if not isinstance(key, str):
288             raise TypeError("key a not string: %r" % key)
289     if NK > 10000 and Hash == StrSaltHash:
290         print(
291             """\
292 WARNING: You have %d keys.
293          Using --hft=1 is likely to fail for so many keys.
294          Please use --hft=2 instead.
295 """
296             % NK
297         )
298
299     # the number of vertices in the graph G
300     NG = NK + 1
301     if verbose:
302         print("NG = %d" % NG)
303
304     trial = 0  # Number of trial graphs so far
305     while True:
306         if (trial % trials) == 0:  # trials failures, increase NG slightly
307             if trial > 0:
308                 NG = max(NG + 1, int(1.05 * NG))
309             if verbose:
310                 sys.stdout.write("\nGenerating graphs NG = %d " % NG)
311         trial += 1
312
313         if NG > 100 * (NK + 1):
314             raise TooManyInterationsError("%d keys" % NK)
315
316         if verbose:
317             sys.stdout.write(".")
318             sys.stdout.flush()
319
320         G = Graph(NG)  # Create graph with NG vertices
321         f1 = Hash(NG)  # Create 2 random hash functions
322         f2 = Hash(NG)
323
324         # Connect vertices given by the values of the two hash functions
325         # for each key.  Associate the desired hash value with each edge.
326         for hashval, key in enumerate(keys):
327             G.connect(f1(key), f2(key), hashval)
328
329         # Try to assign the vertex values.  This will fail when the graph
330         # is cyclic.  But when the graph is acyclic it will succeed and we
331         # break out, because we're done.
332         if G.assign_vertex_values():
333             break
334
335     if verbose:
336         print("\nAcyclic graph found after %d trials." % trial)
337         print("NG = %d" % NG)
338
339     # Sanity check the result by actually verifying that all the keys
340     # hash to the right value.
341     for hashval, key in enumerate(keys):
342         assert hashval == (G.vertex_values[f1(key)] + G.vertex_values[f2(key)]) % NG
343
344     if verbose:
345         print("OK")
346
347     return f1, f2, G.vertex_values
348
349
350 class Format(object):
351     def __init__(self, width=76, indent=4, delimiter=", "):
352         self.width = width
353         self.indent = indent
354         self.delimiter = delimiter
355
356     def print_format(self):
357         print("Format options:")
358         for name in "width", "indent", "delimiter":
359             print("  %s: %r" % (name, getattr(self, name)))
360
361     def __call__(self, data, quote=False):
362         if not isinstance(data, (list, tuple)):
363             return str(data)
364
365         lendel = len(self.delimiter)
366         aux = StringIO()
367         pos = 20
368         for i, elt in enumerate(data):
369             last = bool(i == len(data) - 1)
370
371             s = ('"%s"' if quote else "%s") % elt
372
373             if pos + len(s) + lendel > self.width:
374                 aux.write("\n" + (self.indent * " "))
375                 pos = self.indent
376
377             aux.write(s)
378             pos += len(s)
379             if not last:
380                 aux.write(self.delimiter)
381                 pos += lendel
382
383         return "\n".join(l.rstrip() for l in aux.getvalue().split("\n"))
384
385
386 def generate_code(keys, Hash=StrSaltHash, template=None, options=None):
387     """
388     Takes a list of key value pairs and inserts the generated parameter
389     lists into the 'template' string.  'Hash' is the random hash function
390     generator, and the optional keywords are formating options.
391     The return value is the substituted code template.
392     """
393     f1, f2, G = generate_hash(keys, Hash)
394
395     assert f1.N == f2.N == len(G)
396     try:
397         salt_len = len(f1.salt)
398         assert salt_len == len(f2.salt)
399     except TypeError:
400         salt_len = None
401
402     if template is None:
403         template = builtin_template(Hash)
404
405     if options is None:
406         fmt = Format()
407     else:
408         fmt = Format(
409             width=options.width, indent=options.indent, delimiter=options.delimiter
410         )
411
412     if verbose:
413         fmt.print_format()
414
415     return string.Template(template).substitute(
416         NS=salt_len,
417         S1=fmt(f1.salt),
418         S2=fmt(f2.salt),
419         NG=len(G),
420         G=fmt(G),
421         NK=len(keys),
422         K=fmt(list(keys), quote=True),
423     )
424
425
426 def read_table(filename, options):
427     """
428     Reads keys and desired hash value pairs from a file.  If no column
429     for the hash value is specified, a sequence of hash values is generated,
430     from 0 to N-1, where N is the number of rows found in the file.
431     """
432     if verbose:
433         print("Reading table from file `%s' to extract keys." % filename)
434     try:
435         fi = open(filename)
436     except IOError:
437         sys.exit("Error: Could not open `%s' for reading." % filename)
438
439     keys = []
440
441     if verbose:
442         print("Reader options:")
443         for name in "comment", "splitby", "keycol":
444             print("  %s: %r" % (name, getattr(options, name)))
445
446     for n, line in enumerate(fi):
447         line = line.strip()
448         if not line or line.startswith(options.comment):
449             continue
450
451         if line.count(options.comment):  # strip content after comment
452             line = line.split(options.comment)[0].strip()
453
454         row = [col.strip() for col in line.split(options.splitby)]
455
456         try:
457             key = row[options.keycol - 1]
458         except IndexError:
459             sys.exit(
460                 "%s:%d: Error: Cannot read key, not enough columns." % (filename, n + 1)
461             )
462
463         keys.append(key)
464
465     fi.close()
466
467     if not keys:
468         exit("Error: no keys found in file `%s'." % filename)
469
470     return keys
471
472
473 def read_template(filename):
474     if verbose:
475         print("Reading template from file `%s'" % filename)
476     try:
477         with open(filename, "r") as fi:
478             return fi.read()
479     except IOError:
480         sys.exit("Error: Could not open `%s' for reading." % filename)
481
482
483 def run_code(code):
484     tmpdir = tempfile.mkdtemp()
485     path = join(tmpdir, "t.py")
486     with open(path, "w") as fo:
487         fo.write(code)
488     try:
489         subprocess.check_call([sys.executable, path])
490     except subprocess.CalledProcessError as e:
491         raise AssertionError(e)
492     finally:
493         shutil.rmtree(tmpdir)
494
495
496 def main():
497     from optparse import OptionParser
498
499     usage = "usage: %prog [options] KEYS_FILE [TMPL_FILE]"
500
501     description = """\
502 Generates code for perfect hash functions from
503 a file with keywords and a code template.
504 If no template file is provided, a small built-in Python template
505 is processed and the output code is written to stdout.
506 """
507
508     parser = OptionParser(
509         usage=usage,
510         description=description,
511         prog=sys.argv[0],
512         version="%prog: " + __version__,
513     )
514
515     parser.add_option(
516         "--delimiter",
517         action="store",
518         default=", ",
519         help="Delimiter for list items used in output, "
520         "the default delimiter is '%default'",
521         metavar="STR",
522     )
523
524     parser.add_option(
525         "--indent",
526         action="store",
527         default=4,
528         type="int",
529         help="Make INT spaces at the beginning of a "
530         "new line when generated list is wrapped. "
531         "Default is %default",
532         metavar="INT",
533     )
534
535     parser.add_option(
536         "--width",
537         action="store",
538         default=76,
539         type="int",
540         help="Maximal width of generated list when "
541         "wrapped.  Default width is %default",
542         metavar="INT",
543     )
544
545     parser.add_option(
546         "--comment",
547         action="store",
548         default="#",
549         help="STR is the character, or sequence of "
550         "characters, which marks the beginning "
551         "of a comment (which runs till "
552         "the end of the line), in the input "
553         "KEYS_FILE. "
554         "Default is '%default'",
555         metavar="STR",
556     )
557
558     parser.add_option(
559         "--splitby",
560         action="store",
561         default=",",
562         help="STR is the character by which the columns "
563         "in the input KEYS_FILE are split. "
564         "Default is '%default'",
565         metavar="STR",
566     )
567
568     parser.add_option(
569         "--keycol",
570         action="store",
571         default=1,
572         type="int",
573         help="Specifies the column INT in the input "
574         "KEYS_FILE which contains the keys. "
575         "Default is %default, i.e. the first column.",
576         metavar="INT",
577     )
578
579     parser.add_option(
580         "--trials",
581         action="store",
582         default=5,
583         type="int",
584         help="Specifies the number of trials before "
585         "NG is increased.  A small INT will give "
586         "compute faster, but the array G will be "
587         "large.  A large INT will take longer to "
588         "compute but G will be smaller. "
589         "Default is %default",
590         metavar="INT",
591     )
592
593     parser.add_option(
594         "--hft",
595         action="store",
596         default=1,
597         type="int",
598         help="Hash function type INT.  Possible values "
599         "are 1 (StrSaltHash) and 2 (IntSaltHash). "
600         "The default is %default",
601         metavar="INT",
602     )
603
604     parser.add_option(
605         "-e",
606         "--execute",
607         action="store_true",
608         help="Execute the generated code within " "the Python interpreter.",
609     )
610
611     parser.add_option(
612         "-o",
613         "--output",
614         action="store",
615         help="Specify output FILE explicitly. "
616         "`-o std' means standard output. "
617         "`-o no' means no output. "
618         "By default, the file name is obtained "
619         "from the name of the template file by "
620         "substituting `tmpl' to `code'.",
621         metavar="FILE",
622     )
623
624     parser.add_option("-v", "--verbose", action="store_true", help="verbosity")
625
626     options, args = parser.parse_args()
627
628     if options.trials <= 0:
629         parser.error("trials before increasing N has to be larger than zero")
630
631     global trials
632     trials = options.trials
633
634     global verbose
635     verbose = options.verbose
636
637     if len(args) not in (1, 2):
638         parser.error("incorrect number of arguments")
639
640     if len(args) == 2 and not args[1].count("tmpl"):
641         parser.error("template filename does not contain 'tmpl'")
642
643     if options.hft == 1:
644         Hash = StrSaltHash
645     elif options.hft == 2:
646         Hash = IntSaltHash
647     else:
648         parser.error("Hash function %s not implemented." % options.hft)
649
650     # --------------------- end parsing and checking --------------
651
652     keys_file = args[0]
653
654     if verbose:
655         print("keys_file = %r" % keys_file)
656
657     keys = read_table(keys_file, options)
658
659     if verbose:
660         print("Number os keys: %d" % len(keys))
661
662     tmpl_file = args[1] if len(args) == 2 else None
663
664     if verbose:
665         print("tmpl_file = %r" % tmpl_file)
666
667     template = read_template(tmpl_file) if tmpl_file else None
668
669     if options.output:
670         outname = options.output
671     else:
672         if tmpl_file:
673             if "tmpl" not in tmpl_file:
674                 sys.exit("Hmm, template filename does not contain 'tmpl'")
675             outname = tmpl_file.replace("tmpl", "code")
676         else:
677             outname = "std"
678
679     if verbose:
680         print("outname = %r\n" % outname)
681
682     if outname == "std":
683         outstream = sys.stdout
684     elif outname == "no":
685         outstream = None
686     else:
687         try:
688             outstream = open(outname, "w")
689         except IOError:
690             sys.exit("Error: Could not open `%s' for writing." % outname)
691
692     code = generate_code(keys, Hash, template, options)
693
694     if options.execute or template == builtin_template(Hash):
695         if verbose:
696             print("Executing code...\n")
697         run_code(code)
698
699     if outstream:
700         outstream.write(code)
701         if not outname == "std":
702             outstream.close()
703
704
705 if __name__ == "__main__":
706     main()