Tizen 2.1 base
[platform/core/uifw/ise-engine-sunpinyin.git] / src / slm / slmprune / slmprune.cpp
1 /*
2  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
3  *
4  * Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
5  *
6  * The contents of this file are subject to the terms of either the GNU Lesser
7  * General Public License Version 2.1 only ("LGPL") or the Common Development and
8  * Distribution License ("CDDL")(collectively, the "License"). You may not use this
9  * file except in compliance with the License. You can obtain a copy of the CDDL at
10  * http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
11  * http://www.opensource.org/licenses/lgpl-license.php. See the License for the
12  * specific language governing permissions and limitations under the License. When
13  * distributing the software, include this License Header Notice in each file and
14  * include the full text of the License in the License file as well as the
15  * following notice:
16  *
17  * NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
18  * (CDDL)
19  * For Covered Software in this distribution, this License shall be governed by the
20  * laws of the State of California (excluding conflict-of-law provisions).
21  * Any litigation relating to this License shall be subject to the jurisdiction of
22  * the Federal Courts of the Northern District of California and the state courts
23  * of the State of California, with venue lying in Santa Clara County, California.
24  *
25  * Contributor(s):
26  *
27  * If you wish your version of this file to be governed by only the CDDL or only
28  * the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
29  * include this software in this distribution under the [CDDL or LGPL Version 2.1]
30  * license." If you don't indicate a single choice of license, a recipient has the
31  * option to distribute your version of this file under either the CDDL or the LGPL
32  * Version 2.1, or to extend the choice of license to its licensees as provided
33  * above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
34  * Version 2 license, then the option applies only if the new code is made subject
35  * to such option by the copyright holder.
36  */
37
38 #ifdef HAVE_CONFIG_H
39 #include "config.h"
40 #endif
41
42 #ifdef HAVE_ASSERT_H
43 #include <assert.h>
44 #endif
45
46 #include <stdio.h>
47 #include <math.h>
48
49 #include "../sim_slm.h"
50 #include <algorithm>
51
52 class TNodeInfo {
53 public:
54     double d;
55 #ifndef WORDS_BIGENDIAN
56     unsigned child : 1;
57     unsigned idx : 31;
58 #else
59     unsigned idx : 31;
60     unsigned child : 1;
61 #endif
62
63 public:
64     TNodeInfo(double distance = 0.0, int pos = 0, bool children =
65                   0) : d(distance)
66     {
67         idx = pos; child = (children == 0) ? 0 : 1;
68     }
69
70     bool
71     operator<(const TNodeInfo& r) const
72     {
73         return ((child ^ r.child) == 0) ? (d < r.d) : (child == 0);
74     }
75
76     bool
77     operator==(const TNodeInfo& r) const
78     {
79         return(child == r.child && d == r.d);
80     }
81 };
82
83 class CSlmPruner : public CSIMSlm {
84 public:
85     CSlmPruner() : CSIMSlm(), cut(NULL)
86     {
87     }
88
89     ~CSlmPruner()
90     {
91         if (cut) delete [] cut;
92     }
93
94     void SetCut(int* nCut);
95     void SetReserve(int* nReserve);
96     void Prune();
97     void Write(const char* filename);
98
99 protected:
100     void PruneLevel(int lvl);
101     double CalcDistance(int lvl, int* idx, TSIMWordId* hw);
102     void CalcBOW();
103
104 protected:
105     int* cut;
106     int cache_level, cache_idx; // to accelerate the pruning method
107     double cache_PA, cache_PB;
108 };
109
110 void
111 CSlmPruner::Prune()
112 {
113     printf("Erasing items using Entropy distance"); fflush(stdout);
114     for (int lvl = N; lvl > 0; --lvl)
115         PruneLevel(lvl);
116     printf("\n"); fflush(stdout);
117     CalcBOW();
118 }
119 void
120 CSlmPruner::Write(const char* filename)
121 {
122     FILE* out = fopen(filename, "wb");
123     fwrite(&N, sizeof(N), 1, out);
124     fwrite(&bUseLogPr, sizeof(bUseLogPr), 1, out);
125     fwrite(sz, sizeof(int), N + 1, out);
126     for (int i = 0; i < N; ++i) {
127         fwrite(level[i], sizeof(TNode), sz[i], out);
128     }
129     fwrite(level[N], sizeof(TLeaf), sz[N], out);
130     fclose(out);
131 }
132
133 void
134 CSlmPruner::SetReserve(int* nReserve)
135 {
136     cut = new int [N + 1];
137     cut[0] = 0;
138     for (int lvl = 1; lvl <= N; ++lvl) {
139         cut[lvl] = sz[lvl] - 1 - nReserve[lvl];
140         if (cut[lvl] < 0) cut[lvl] = 0;
141     }
142 }
143
144 void
145 CSlmPruner::SetCut(int* nCut)
146 {
147     cut = new int [N + 1];
148     cut[0] = 0;
149     for (int lvl = 1; lvl <= N; ++lvl)
150         cut[lvl] = nCut[lvl];
151 }
152
153 template <class chIterator>
154 int
155 CutLevel(CSIMSlm::TNode* pfirst,
156          CSIMSlm::TNode* plast,
157          chIterator chfirst,
158          chIterator chlast,
159          bool bUseLogPr)
160 {
161     int idxfirst, idxchk;
162     chIterator chchk = chfirst;
163     for (idxfirst = idxchk = 0; chchk != chlast; ++chchk, ++idxchk) {
164         //cut item whoese pr == 1.0; and not psuedo tail
165         if (chchk->pr != ((bUseLogPr) ? 0.0 : 1.0) || (chchk + 1) == chlast) {
166             if (idxfirst < idxchk) *chfirst = *chchk;
167             while (pfirst != plast && pfirst->child <= idxchk)
168                 pfirst++->child = idxfirst;
169             ++idxfirst;
170             ++chfirst;
171         }
172     }
173     return idxfirst;
174 }
175
176 void
177 CSlmPruner::PruneLevel(int lvl)
178 {
179     cache_level = cache_idx = -1;
180
181     if (cut[lvl] <= 0) {
182         printf("\n  Level %d (%d items), no need to cut as your command!",
183                lvl,
184                sz[lvl] - 1); fflush(stdout);
185         return;
186     }
187
188     printf("\n  Level %d (%d items), allocating...", lvl, sz[lvl] - 1); fflush(
189         stdout);
190
191     int n = sz[lvl] - 1; //do not count last psuedo tail
192     if (cut[lvl] >= n) cut[lvl] = n - 1;
193     TNodeInfo* pbuf = new TNodeInfo[n];
194     TSIMWordId hw[16]; // it should be lvl+1, yet some compiler do not support it
195     int idx[16];       // it should be lvl+1, yet some compiler do not support it
196
197     printf(", Calculating..."); fflush(stdout);
198     for (int i = 0; i <= lvl; ++i)
199         idx[i] = 0;
200     while (idx[lvl] < n) {
201         if (lvl == N) {
202             hw[lvl] = (((TLeaf*)level[lvl]) + idx[lvl])->id;
203         } else {
204             hw[lvl] = (((TNode*)level[lvl]) + idx[lvl])->id;
205         }
206         for (int j = lvl - 1; j >= 0; --j) {
207             TNode* pnode = ((TNode*)level[j]) + idx[j];
208             for (; (pnode + 1)->child <= idx[j + 1]; ++pnode, ++idx[j])
209                 ;
210             hw[j] = pnode->id;
211         }
212         bool has_child = false;
213         if (lvl != N) {
214             TNode* pn = ((TNode*)level[lvl]) + idx[lvl];
215             if ((pn + 1)->child > pn->child)
216                 has_child = true;
217         }
218         pbuf[idx[lvl]].child = (has_child) ? 1 : 0;
219         pbuf[idx[lvl]].idx = idx[lvl];
220         if (!has_child)
221             pbuf[idx[lvl]].d = CalcDistance(lvl, idx, hw);
222         ++idx[lvl];
223     }
224     printf(", sorting...");
225     std::make_heap(pbuf, pbuf + n);
226     std::sort_heap(pbuf, pbuf + n);
227
228     int k = 0;
229     // because pr in model can not be 1.0, so we use this to mark a item to be prune
230     for (TNodeInfo* pinfo = pbuf;
231          k < cut[lvl] && pinfo->child == 0;
232          ++k, ++pinfo) {
233         if (lvl == N) {
234             if (bUseLogPr)
235                 (((TLeaf*)level[lvl]) + pinfo->idx)->pr = 0.0;  // -log(1.0)
236             else
237                 (((TLeaf*)level[lvl]) + pinfo->idx)->pr = 1.0;
238         } else {
239             if (bUseLogPr)
240                 (((TNode*)level[lvl]) + pinfo->idx)->pr = 0.0;  // -log(1.0)
241             else
242                 (((TNode*)level[lvl]) + pinfo->idx)->pr = 1.0;  // -log(1.0)
243         }
244     }
245     printf("(cut %d items), build parent ptr...", k); fflush(stdout);
246     if (lvl == N) {
247         k =
248             CutLevel((TNode*)level[lvl - 1],
249                      ((TNode*)level[lvl - 1]) + sz[lvl - 1],
250                      (TLeaf*)level[lvl],
251                      ((TLeaf*)level[lvl]) + sz[lvl],
252                      bUseLogPr);
253     } else {
254         k =
255             CutLevel((TNode*)level[lvl - 1],
256                      ((TNode*)level[lvl - 1]) + sz[lvl - 1],
257                      (TNode*)level[lvl],
258                      ((TNode*)level[lvl]) + sz[lvl],
259                      bUseLogPr);
260     }
261     sz[lvl] = k; //k is new size
262     printf("done!");
263     delete [] pbuf;
264     cache_level = cache_idx = -1;
265 }
266
267 template<class chIterator>
268 double
269 CalcNodeBow(CSlmPruner* pruner,
270             int lvl,
271             TSIMWordId words[],
272             chIterator chh,
273             chIterator cht,
274             bool bUseLogPr)
275 {
276     double sumnext = 0.0, sum = 0.0;
277     if (chh == cht)
278         return 1.0;
279     for (; chh < cht; ++chh) {
280         if (bUseLogPr)
281             sumnext += exp(-double(chh->pr));
282         else
283             sumnext += double(chh->pr);
284         words[lvl + 1] = chh->id;
285         sum += pruner->getPr(lvl, words + 2);
286     }
287     assert(sumnext >= 0.0 && sumnext < 1.0);
288     assert(sum >= 0.0 && sum < 1.0);
289     return (1.0 - sumnext) / (1.0 - sum);
290 }
291
292 void
293 CSlmPruner::CalcBOW()
294 {
295     printf("\nUpdating back-off weight"); fflush(stdout);
296     for (int lvl = 0; lvl < N; ++lvl) {
297         printf("\n    Level %d...", lvl); fflush(stdout);
298         TNode* base[16]; //it should be lvl+1, yet some compiler do not support it
299         int idx[16];     //it should be lvl+1, yet some compiler do not support it
300         for (int i = 0; i <= lvl; ++i) {
301             base[i] = (TNode*)level[i];
302             idx[i] = 0;
303         }
304         TSIMWordId words[17];   //it should be lvl+2, yet some compiler do not support it
305         for (int lsz = sz[lvl] - 1; idx[lvl] < lsz; ++idx[lvl]) {
306             words[lvl] = base[lvl][idx[lvl]].id;
307             for (int k = lvl - 1; k >= 0; --k) {
308                 while (base[k][idx[k] + 1].child <= idx[k + 1])
309                     ++idx[k];
310                 words[k] = base[k][idx[k]].id;
311             }
312             TNode & node = base[lvl][idx[lvl]];
313             TNode & nodenext = *((&node) + 1);
314
315             double bow = 1.0;
316             if (lvl == N - 1) {
317                 TLeaf* ch = (TLeaf*)level[lvl + 1];
318                 bow =
319                     CalcNodeBow(this, lvl, words, &(ch[node.child]),
320                                 &(ch[nodenext.child]), bUseLogPr);
321             } else {
322                 TNode* ch = (TNode*)level[lvl + 1];
323                 bow =
324                     CalcNodeBow(this, lvl, words, &(ch[node.child]),
325                                 &(ch[nodenext.child]), bUseLogPr);
326             }
327             if (bUseLogPr)
328                 node.bow = PR_TYPE(-log(bow));
329             else
330                 node.bow = PR_TYPE(bow);
331         }
332     }
333     printf("\n"); fflush(stdout);
334 }
335
336 double
337 CSlmPruner::CalcDistance(int lvl, int* idx, TSIMWordId* hw)
338 {
339     double PA, PB, PHW, PH_W, PH, BOW, _BOW, pr, p_r;
340     TSIMWordId w = hw[lvl];
341
342     PH = 1.0;
343     TNode* parent = ((TNode*)level[lvl - 1]) + idx[lvl - 1];
344     if (bUseLogPr)
345         BOW = exp(-double(parent->bow));  //Fix original bug to use the BOW directly
346     else
347         BOW = double(parent->bow);
348
349     for (int i = 1; i < lvl; ++i)
350         PH *= getPr(i, hw + 1 + (lvl - i));
351     assert(PH <= 1.0 && PH > 0.0);
352
353     if (lvl == N) {
354         if (bUseLogPr)
355             PHW = exp(-((((TLeaf*)level[lvl]) + idx[lvl])->pr));
356         else
357             PHW = ((((TLeaf*)level[lvl]) + idx[lvl])->pr);
358         assert(w == (((TLeaf*)level[lvl]) + idx[lvl])->id);
359     } else {
360         if (bUseLogPr)
361             PHW = exp(-((((TNode*)level[lvl]) + idx[lvl])->pr));
362         else
363             PHW = ((((TNode*)level[lvl]) + idx[lvl])->pr);
364         assert(w == (((TNode*)level[lvl]) + idx[lvl])->id);
365     }
366     PH_W = getPr(lvl - 1, hw + 2);
367     assert(PHW > 0.0 && PHW < 1.0);
368     assert(PH_W > 0.0 && PH_W < 1.0);
369
370     if (cache_level != lvl - 1 || cache_idx != idx[lvl - 1]) {
371         cache_level = lvl - 1;
372         cache_idx = idx[lvl - 1];
373         cache_PA = cache_PB = 1.0;
374         for (int h = parent->child, t = (parent + 1)->child; h < t; ++h) {
375             TSIMWordId id;
376             if (lvl == N) {
377                 if (bUseLogPr)
378                     pr = exp(-((((TLeaf*)level[lvl]) + h)->pr));
379                 else
380                     pr = ((((TLeaf*)level[lvl]) + h)->pr);
381                 id = (((TLeaf*)level[lvl]) + h)->id;
382             } else {
383                 if (bUseLogPr)
384                     pr = exp(-((((TNode*)level[lvl]) + h)->pr));
385                 else
386                     pr = ((((TNode*)level[lvl]) + h)->pr);
387                 id = (((TNode*)level[lvl]) + h)->id;
388             }
389             assert(pr > 0.0 && pr < 1.0);
390             cache_PA -= pr;
391
392             hw[lvl] = id;
393             p_r = getPr(lvl - 1, hw + 2);  // Fix bug from pr = getPr(lvl-1, hw+1)
394             assert(p_r > 0.0 && p_r < 1.0);
395             cache_PB -= p_r;
396         }
397         assert(cache_PA > -0.01 && cache_PB > -0.01);
398         if (cache_PA < 0.00001 || cache_PB < 0.00001) {
399             printf("\n precision problem on %d gram:", lvl - 1);
400             for (int i = 1; i < lvl; ++i) printf("%d ", idx[i]);
401             printf("   ");
402             if (cache_PA < 0.00001) {
403                 printf("{1.0 - sigma p(w|h)} ==> 0.00001");
404                 cache_PA = 0.00001;
405             }
406             if (cache_PB < 0.00001) {
407                 printf("{1.0 - sigma p(w|h')} ==> 0.00001");
408                 cache_PB = 0.00001;
409             }
410         }
411     }
412     PA = cache_PA;
413     PB = cache_PB;
414
415     _BOW = (PA + PHW) / (PB + PH_W); // Fix bug from "(1.0-PA+PHW)/(1.0-PB+PH_W);"
416
417     assert(BOW > 0.0);
418     assert(_BOW > 0.0);
419     assert(PA + PHW < 1.01);     // %1 error rate
420     assert(PB + PH_W < 1.01);    // %1 error rate
421
422     /*
423      * PH = P(h), PHW = P(w|h), PH_W = P(w|h'), _BOW = bow'(h) (the new bow)
424      * BOW = bow(h) (the original bow), PA = sum_{w_i:C(w_i,h)=0} P(w_i|h),
425      * PB = sum_{w_i:C(w_i,h)=0} P(w_i|h')
426      */
427     return -(PH *
428              (PHW *
429               (log(PH_W) + log(_BOW) - log(PHW)) + PA * (log(_BOW) - log(BOW))));
430 }
431
432 void
433 ShowUsage(void)
434 {
435     printf("Usage:\n");
436     printf("    slmprune input_slm result_slm [R|C] num1 num2...\n");
437     printf("\nDescription:\n");
438     printf(
439         "\
440       This program uses entropy-based method to prune the size of back-off \n\
441   language model 'input_slm' to a specific size and write to 'result_slm'. \n\
442   the third parameter [R|C] means the following numbers is the number for\n\
443   (R)eserve or (C)ut. If (C)ut, the num[k] means how many items in level K\n\
444   would be cut. If (R)eserve, num[k] means how many item would be reserved\n\
445   in level k. \n\
446       Note that we do not ensure that during pruning process,  exactly the\n\
447   the given number of items are cut or reserved, because some items may \n\
448   contains high level children, so could not be cut. \n\
449       Also it's your responsiblity to give right number of arguments based\n\
450   on 'input_slm'.\n\
451 \nSee Also:\n\
452     To get information of the back-off language model, try 'slminfo'.\n\n");
453 }
454
455 int nCut[32];
456 const char* srcfilename, *tgtfilename;
457
458 int
459 main(int argc, char* argv[])
460 {
461     memset(nCut, 0, sizeof(nCut));
462     if (argc < 5) {
463         ShowUsage(); exit(100);
464     }
465     srcfilename = argv[1];
466     tgtfilename = argv[2];
467     bool bCut = (argv[3][0] == 'C' || argv[3][0] == 'c');
468
469     CSlmPruner pruner;
470     printf("Reading language model %s...", srcfilename); fflush(stdout);
471     pruner.Load(srcfilename);
472     printf("done!\n"); fflush(stdout);
473
474     for (int i = 4; i < argc && i < 100; ++i)
475         nCut[i - 3] = atoi(argv[i]);
476
477     if (bCut)
478         pruner.SetCut(nCut);
479     else
480         pruner.SetReserve(nCut);
481     pruner.Prune();
482
483     printf("Writing target language model %s...", tgtfilename); fflush(stdout);
484     pruner.Write(tgtfilename);
485     printf("done!\n\n"); fflush(stdout);
486
487     pruner.Free();
488     return 0;
489 }