Tizen 2.1 base
[platform/core/uifw/ise-engine-sunpinyin.git] / python / trie.py
1 #!/usr/bin/python
2
3 # DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
4
5 # Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
6
7 # The contents of this file are subject to the terms of either the GNU Lesser
8 # General Public License Version 2.1 only ("LGPL") or the Common Development and
9 # Distribution License ("CDDL")(collectively, the "License"). You may not use this
10 # file except in compliance with the License. You can obtain a copy of the CDDL at
11 # http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
12 # http://www.opensource.org/licenses/lgpl-license.php. See the License for the 
13 # specific language governing permissions and limitations under the License. When
14 # distributing the software, include this License Header Notice in each file and
15 # include the full text of the License in the License file as well as the
16 # following notice:
17
18 # NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
19 # (CDDL)
20 # For Covered Software in this distribution, this License shall be governed by the
21 # laws of the State of California (excluding conflict-of-law provisions).
22 # Any litigation relating to this License shall be subject to the jurisdiction of
23 # the Federal Courts of the Northern District of California and the state courts
24 # of the State of California, with venue lying in Santa Clara County, California.
25
26 # Contributor(s):
27
28 # If you wish your version of this file to be governed by only the CDDL or only
29 # the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
30 # include this software in this distribution under the [CDDL or LGPL Version 2.1]
31 # license." If you don't indicate a single choice of license, a recipient has the
32 # option to distribute your version of this file under either the CDDL or the LGPL
33 # Version 2.1, or to extend the choice of license to its licensees as provided
34 # above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
35 # Version 2 license, then the option applies only if the new code is made subject
36 # to such option by the copyright holder. 
37
38 __all__ = ['Trie', 'DATrie', 'match_longest', 'get_ambiguious_length']
39
40 from math import log
41 import struct
42
43 class Trie (object):
44     class TrieNode:
45         def __init__ (self):
46             self.val = 0
47             self.trans = {}
48
49     def __init__(self):
50         self.root = Trie.TrieNode()
51
52     def add(self, word, value=1):
53         curr_node = self.root
54         for ch in word:
55             try: 
56                 curr_node = curr_node.trans[ch]
57             except:
58                 curr_node.trans[ch] = Trie.TrieNode()
59                 curr_node = curr_node.trans[ch]
60
61         curr_node.val = value
62
63     def walk (self, trienode, ch):
64         if ch in trienode.trans:
65             trienode = trienode.trans[ch]
66             return trienode, trienode.val
67         else:
68             return None, 0
69
70 class FlexibleList (list):
71     def __check_size (self, index):
72         if index >= len(self):
73             self.extend ([0] * (index-len(self)+1))
74
75     def __getitem__ (self, index):
76         self.__check_size (index)
77         return list.__getitem__(self, index)
78
79     def __setitem__ (self, index, value):
80         self.__check_size (index)
81         return list.__setitem__(self, index, value)
82
83 def character_based_encoder (ch, range=('a', 'z')):
84     ret = ord(ch) - ord(range[0]) + 1
85     if ret <= 0: ret = ord(range[1]) + 1
86     return ret
87
88 class DATrie (object):
89     def __init__(self, chr_encoder=character_based_encoder):
90         self.root = 0
91         self.chr_encoder = chr_encoder
92         self.clear()
93
94     def clear (self):
95         self.base  = FlexibleList ()
96         self.check = FlexibleList ()
97         self.value = FlexibleList ()
98
99     def walk (self, s, ch):
100         c = self.chr_encoder (ch)
101         t = abs(self.base[s]) + c
102
103         if t<len(self.check) and self.check[t] == s and self.base[t]:
104             if self.value: 
105                 v = self.value[t]
106             else: 
107                 v = -1 if self.base[t] < 0 else 0
108             return t, v
109         else:
110             return 0, 0
111
112     def find_base (self, s, children, i=1):
113         if s == 0 or not children:
114             return s
115
116         i = max (i, 1)
117         loop_times = 0
118         while True:
119             for ch in children:
120                 k = i + self.chr_encoder (ch)
121                 if self.base[k] or self.check[k] or k == s:
122                     loop_times += 1
123                     i += int (log (loop_times, 2)) + 1
124                     break
125             else:
126                 break
127
128         return i
129
130     def build (self, words, values=None):
131         assert (not values or (len(words) == len(values)))
132         itval = iter(values) if values else None
133
134         trie = Trie()
135         for w in words:
136             trie.add (w, itval.next() if itval else -1)
137
138         self.construct_from_trie (trie, values!=None)
139
140     def construct_from_trie (self, trie, with_value=True, progress_cb=None, progress_cb_thr=100):
141         nodes = [(trie.root, 0)]
142         find_from = 1
143         loop_times = 0
144
145         while nodes:
146             trienode, s = nodes.pop(0)
147             find_from = b = self.find_base (s, trienode.trans, find_from)
148             self.base[s] = -b if trienode.val else b
149             if with_value: self.value[s] = trienode.val
150
151             for ch in trienode.trans:
152                 c = self.chr_encoder (ch)
153                 t = abs(self.base[s]) + c
154                 self.check[t] = s if s else -1
155
156                 nodes.append ((trienode.trans[ch], t))
157
158             loop_times += 1
159             if loop_times == progress_cb_thr:
160                 loop_times = 0
161                 if progress_cb:
162                     progress_cb ()
163
164         for i in xrange (self.chr_encoder (max(trie.root.trans))+1):
165             if self.check[i] == -1:
166                 self.check[i] = 0
167
168     def save (self, fname):
169         f = open (fname, 'w+')
170         l = len (self.base)
171
172         using_32bits = l > 2**15
173         elm_size = 4 if using_32bits else 2
174         fmt_str = '%di'%l if using_32bits else '%dh'%l
175
176         # the data types here should be aligned with those in datrie.h
177         f.write (struct.pack ('I', l))
178         f.write (struct.pack ('H', elm_size))
179         f.write (struct.pack ('H', 1 if self.value else 0))
180
181         f.write (struct.pack (fmt_str, *self.base))
182         f.write (struct.pack (fmt_str, *self.check))
183
184         if self.value:
185             if len(self.value) < l: self.value[l-1] = 0
186             f.write (struct.pack ('%di'%l, *self.value))
187
188         f.close()
189
190     def output_static_c_arrays (self, fname):
191         f = open(fname, 'w+')
192         l = len (self.base)
193
194         type = "int" if l > 2**15 else "short"
195
196         f.write (self.__to_c_array (self.base,  type,  "base"))
197         f.write (self.__to_c_array (self.check, type,  "check"))
198         f.write (self.__to_c_array (self.value, "int", "value"))
199
200         f.close()
201
202     def __to_c_array (self, array, type, name):
203         return "static %s %s[] = {%s};\n\n" % (type, name, ', '.join (str(i) for i in array))
204
205     def load (self, fname):
206         f = open (fname, 'r')
207
208         l = struct.unpack ('I', f.read(4))[0]
209         elm_size = struct.unpack ('H', f.read(2))[0]
210         has_value = struct.unpack ('H', f.read(2))[0]
211
212         fmt_str = '%di'%l if elm_size == 4 else '%dh'%l
213         self.base  = struct.unpack (fmt_str, f.read(l*elm_size))
214         self.check = struct.unpack (fmt_str, f.read(l*elm_size))
215         self.value = struct.unpack ('%di'%l, f.read(l*4)) if has_value else []
216
217         f.close()
218
219 def search (trie, word):
220     curr_node = trie.root
221
222     for ch in word:
223         curr_node, val = trie.walk (curr_node, ch)
224         if not curr_node: 
225             break
226     else:
227         return val
228
229     return 0
230
231 def match_longest (trie, word):
232     l = ret_l = ret_v = 0
233     curr_node = trie.root
234
235     for ch in word:
236         curr_node, val = trie.walk (curr_node, ch)
237         if not curr_node: 
238             break
239
240         l += 1
241         if val: 
242             ret_l, ret_v = l, val
243
244     return ret_v, ret_l
245
246 def get_ambiguious_length (trie, str, word_len):
247     i = 1
248     while i < word_len and i < len(str):
249         wid, l = match_longest(trie, str[i:])
250         if word_len < i + l:
251             word_len = i + l
252         i += 1
253     return i
254
255 def test ():
256     from pinyin_data import valid_syllables
257
258     trie = Trie()
259     for s in valid_syllables:
260         trie.add (s, valid_syllables[s])
261
262     for s in valid_syllables:
263         v, l = match_longest (trie, s+'b')
264         assert (len(s) == l and valid_syllables[s] == v)
265
266     datrie = DATrie()
267     datrie.construct_from_trie (trie)
268
269     datrie.save ('/tmp/trie_test')
270     datrie.load ('/tmp/trie_test')
271
272     for s in valid_syllables:
273         v, l = match_longest (datrie, s+'b')
274         assert (len(s) == l and valid_syllables[s] == v)
275
276     print 'test executed successfully'
277
278 if __name__ == "__main__":
279     test ()