# Copyright (c) 2014 The Native Client Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
11 class TrieTest(unittest.TestCase):
13 def MakeUncompressedTrie(self):
14 uncompressed = trie.Node()
15 accept = trie.AcceptInfo(input_rr='%eax', output_rr='%edx')
16 trie.AddToUncompressedTrie(uncompressed, ['0', '1', '2'], accept)
17 trie.AddToUncompressedTrie(uncompressed, ['0', '1', '2', '3'], accept)
18 trie.AddToUncompressedTrie(uncompressed, ['0', '1', '3'], accept)
19 trie.AddToUncompressedTrie(uncompressed, ['0', '1', '4'], accept)
20 trie.AddToUncompressedTrie(uncompressed, ['0', '1', '5'], accept)
23 def CheckTrieAccepts(self, accept_sequences):
24 accept = trie.AcceptInfo(input_rr='%eax', output_rr='%edx')
25 self.assertEquals([(accept, ['0', '1', '2']),
26 (accept, ['0', '1', '2', '3']),
27 (accept, ['0', '1', '3']),
28 (accept, ['0', '1', '4']),
29 (accept, ['0', '1', '5'])],
32 def testTrieAddAndMerge(self):
33 uncompressed = self.MakeUncompressedTrie()
34 self.CheckTrieAccepts(trie.GetAllAcceptSequences(uncompressed))
35 # n0 -0-> n1 -1-> n2 -2-> n3 -3-> n4
39 self.assertEquals(8, len(trie.GetAllUniqueNodes(uncompressed)))
41 node_cache = trie.NodeCache()
42 compressed_trie = node_cache.Merge(node_cache.empty_node, uncompressed)
43 self.CheckTrieAccepts(trie.GetAllAcceptSequences(compressed_trie))
44 # (n4, n5. n6, n7) can be grouped together from above
45 self.assertEquals(5, len(trie.GetAllUniqueNodes(compressed_trie)))
47 def testTrieSerializationAndDeserialization(self):
48 uncompressed = self.MakeUncompressedTrie()
49 node_cache = trie.NodeCache()
50 compressed_trie = node_cache.Merge(node_cache.empty_node, uncompressed)
51 reconstructed_trie = trie.TrieFromDict(trie.TrieToDict(compressed_trie),
53 self.CheckTrieAccepts(trie.GetAllAcceptSequences(reconstructed_trie))
54 self.assertEquals(5, len(trie.GetAllUniqueNodes(reconstructed_trie)))
56 def testTrieDiff(self):
59 accept1 = trie.AcceptInfo(input_rr='%eax', output_rr='%edx')
60 accept2 = trie.AcceptInfo(input_rr='%eax', output_rr='%ecx')
62 trie.AddToUncompressedTrie(trie1, ['0', '1', '2'], accept1)
63 trie.AddToUncompressedTrie(trie1, ['0', '1', '3'], accept1)
64 trie.AddToUncompressedTrie(trie1, ['0', '1', '4'], accept1)
65 trie.AddToUncompressedTrie(trie1, ['0', '1', '5'], accept1)
67 trie.AddToUncompressedTrie(trie2, ['0', '1', '2'], accept1)
68 trie.AddToUncompressedTrie(trie2, ['0', '1', '3'], accept1)
69 trie.AddToUncompressedTrie(trie2, ['0', '1', '4'], accept2)
71 node_cache = trie.NodeCache()
72 compressed_trie1 = node_cache.Merge(node_cache.empty_node, trie1)
73 compressed_trie2 = node_cache.Merge(node_cache.empty_node, trie2)
76 compressed_diffs = set()
78 trie.DiffTries(trie1, trie2, node_cache.empty_node,
80 trie.DiffTries(compressed_trie1, compressed_trie2, node_cache.empty_node,
81 compressed_diffs.add, ())
85 set([(('0', '1', '4'), accept1, accept2),
86 (('0', '1', '5'), accept1, None)]))
87 self.assertEquals(diffs, compressed_diffs)
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()