3 # DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
5 # Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
7 # The contents of this file are subject to the terms of either the GNU Lesser
8 # General Public License Version 2.1 only ("LGPL") or the Common Development and
9 # Distribution License ("CDDL")(collectively, the "License"). You may not use this
10 # file except in compliance with the License. You can obtain a copy of the CDDL at
11 # http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
12 # http://www.opensource.org/licenses/lgpl-license.php. See the License for the
13 # specific language governing permissions and limitations under the License. When
14 # distributing the software, include this License Header Notice in each file and
15 # include the full text of the License in the License file as well as the
18 # NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
20 # For Covered Software in this distribution, this License shall be governed by the
21 # laws of the State of California (excluding conflict-of-law provisions).
22 # Any litigation relating to this License shall be subject to the jurisdiction of
23 # the Federal Courts of the Northern District of California and the state courts
24 # of the State of California, with venue lying in Santa Clara County, California.
28 # If you wish your version of this file to be governed by only the CDDL or only
29 # the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
30 # include this software in this distribution under the [CDDL or LGPL Version 2.1]
31 # license." If you don't indicate a single choice of license, a recipient has the
32 # option to distribute your version of this file under either the CDDL or the LGPL
33 # Version 2.1, or to extend the choice of license to its licensees as provided
34 # above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
35 # Version 2 license, then the option applies only if the new code is made subject
36 # to such option by the copyright holder.
# Public names exported by ``from <this module> import *``.
# NOTE: "get_ambiguious_length" is misspelled, but it is the public name
# callers use, so it is kept as-is for compatibility.
__all__ = ['Trie', 'DATrie', 'match_longest', 'get_ambiguious_length']
50 self.root = Trie.TrieNode()
52 def add(self, word, value=1):
56 curr_node = curr_node.trans[ch]
58 curr_node.trans[ch] = Trie.TrieNode()
59 curr_node = curr_node.trans[ch]
63 def walk (self, trienode, ch):
64 if ch in trienode.trans:
65 trienode = trienode.trans[ch]
66 return trienode, trienode.val
class FlexibleList (list):
    """A ``list`` that grows on demand when indexed past its end.

    Reading or writing an integer index ``i >= len(self)`` first pads the
    list with ``0`` up to and including position ``i``. Positive indices
    therefore never raise ``IndexError``. Negative indices and slices
    behave exactly as they do on a plain ``list``.
    """

    def __check_size (self, index):
        """Pad with zeros so that the integer ``index`` becomes a valid position.

        Slices are passed through untouched. Slicing never needs the list
        to grow, and the size arithmetic below would raise ``TypeError``
        for a slice object.
        """
        if isinstance(index, slice):
            return
        if index >= len(self):
            self.extend([0] * (index - len(self) + 1))

    def __getitem__ (self, index):
        # A read past the end grows the list as a side effect (by design).
        self.__check_size(index)
        return list.__getitem__(self, index)

    def __setitem__ (self, index, value):
        self.__check_size(index)
        return list.__setitem__(self, index, value)
83 def character_based_encoder (ch, range=('a', 'z')):
84 ret = ord(ch) - ord(range[0]) + 1
85 if ret <= 0: ret = ord(range[1]) + 1
88 class DATrie (object):
89 def __init__(self, chr_encoder=character_based_encoder):
91 self.chr_encoder = chr_encoder
95 self.base = FlexibleList ()
96 self.check = FlexibleList ()
97 self.value = FlexibleList ()
99 def walk (self, s, ch):
100 c = self.chr_encoder (ch)
101 t = abs(self.base[s]) + c
103 if t<len(self.check) and self.check[t] == s and self.base[t]:
107 v = -1 if self.base[t] < 0 else 0
112 def find_base (self, s, children, i=1):
113 if s == 0 or not children:
120 k = i + self.chr_encoder (ch)
121 if self.base[k] or self.check[k] or k == s:
123 i += int (log (loop_times, 2)) + 1
130 def build (self, words, values=None):
131 assert (not values or (len(words) == len(values)))
132 itval = iter(values) if values else None
136 trie.add (w, itval.next() if itval else -1)
138 self.construct_from_trie (trie, values!=None)
140 def construct_from_trie (self, trie, with_value=True, progress_cb=None, progress_cb_thr=100):
141 nodes = [(trie.root, 0)]
146 trienode, s = nodes.pop(0)
147 find_from = b = self.find_base (s, trienode.trans, find_from)
148 self.base[s] = -b if trienode.val else b
149 if with_value: self.value[s] = trienode.val
151 for ch in trienode.trans:
152 c = self.chr_encoder (ch)
153 t = abs(self.base[s]) + c
154 self.check[t] = s if s else -1
156 nodes.append ((trienode.trans[ch], t))
159 if loop_times == progress_cb_thr:
164 for i in xrange (self.chr_encoder (max(trie.root.trans))+1):
165 if self.check[i] == -1:
168 def save (self, fname):
169 f = open (fname, 'w+')
172 using_32bits = l > 2**15
173 elm_size = 4 if using_32bits else 2
174 fmt_str = '%di'%l if using_32bits else '%dh'%l
176 # the data types here should be aligned with those in datrie.h
177 f.write (struct.pack ('I', l))
178 f.write (struct.pack ('H', elm_size))
179 f.write (struct.pack ('H', 1 if self.value else 0))
181 f.write (struct.pack (fmt_str, *self.base))
182 f.write (struct.pack (fmt_str, *self.check))
185 if len(self.value) < l: self.value[l-1] = 0
186 f.write (struct.pack ('%di'%l, *self.value))
190 def output_static_c_arrays (self, fname):
191 f = open(fname, 'w+')
194 type = "int" if l > 2**15 else "short"
196 f.write (self.__to_c_array (self.base, type, "base"))
197 f.write (self.__to_c_array (self.check, type, "check"))
198 f.write (self.__to_c_array (self.value, "int", "value"))
202 def __to_c_array (self, array, type, name):
203 return "static %s %s[] = {%s};\n\n" % (type, name, ', '.join (str(i) for i in array))
205 def load (self, fname):
206 f = open (fname, 'r')
208 l = struct.unpack ('I', f.read(4))[0]
209 elm_size = struct.unpack ('H', f.read(2))[0]
210 has_value = struct.unpack ('H', f.read(2))[0]
212 fmt_str = '%di'%l if elm_size == 4 else '%dh'%l
213 self.base = struct.unpack (fmt_str, f.read(l*elm_size))
214 self.check = struct.unpack (fmt_str, f.read(l*elm_size))
215 self.value = struct.unpack ('%di'%l, f.read(l*4)) if has_value else []
219 def search (trie, word):
220 curr_node = trie.root
223 curr_node, val = trie.walk (curr_node, ch)
231 def match_longest (trie, word):
232 l = ret_l = ret_v = 0
233 curr_node = trie.root
236 curr_node, val = trie.walk (curr_node, ch)
242 ret_l, ret_v = l, val
246 def get_ambiguious_length (trie, str, word_len):
248 while i < word_len and i < len(str):
249 wid, l = match_longest(trie, str[i:])
256 from pinyin_data import valid_syllables
259 for s in valid_syllables:
260 trie.add (s, valid_syllables[s])
262 for s in valid_syllables:
263 v, l = match_longest (trie, s+'b')
264 assert (len(s) == l and valid_syllables[s] == v)
267 datrie.construct_from_trie (trie)
269 datrie.save ('/tmp/trie_test')
270 datrie.load ('/tmp/trie_test')
272 for s in valid_syllables:
273 v, l = match_longest (datrie, s+'b')
274 assert (len(s) == l and valid_syllables[s] == v)
276 print 'test executed successfully'
278 if __name__ == "__main__":