Initial import to Tizen
[profile/ivi/sphinxbase.git] / src / libsphinxbase / lm / lm3g_templates.c
1 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /* ====================================================================
3  * Copyright (c) 1999-2007 Carnegie Mellon University.  All rights
4  * reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer. 
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  *
18  * This work was supported in part by funding from the Defense Advanced 
19  * Research Projects Agency and the National Science Foundation of the 
20  * United States of America, and the CMU Sphinx Speech Consortium.
21  *
22  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
23  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
24  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
28  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
29  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
30  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
31  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
32  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  *
34  * ====================================================================
35  *
36  */
37 /*
38  * \file lm3g_templates.c Core Sphinx 3-gram code used in
39  * DMP/DMP32/ARPA (for now) model code.
40  */
41
42 #include <assert.h>
43
44 /* Locate a specific bigram within a bigram list */
45 #define BINARY_SEARCH_THRESH    16
46 static int32
47 find_bg(bigram_t * bg, int32 n, int32 w)
48 {
49     int32 i, b, e;
50
51     /* Binary search until segment size < threshold */
52     b = 0;
53     e = n;
54     while (e - b > BINARY_SEARCH_THRESH) {
55         i = (b + e) >> 1;
56         if (bg[i].wid < w)
57             b = i + 1;
58         else if (bg[i].wid > w)
59             e = i;
60         else
61             return i;
62     }
63
64     /* Linear search within narrowed segment */
65     for (i = b; (i < e) && (bg[i].wid != w); i++);
66     return ((i < e) ? i : -1);
67 }
68
69 static int32
70 lm3g_bg_score(NGRAM_MODEL_TYPE *model,
71               int32 lw1, int32 lw2, int32 *n_used)
72 {
73     int32 i, n, b, score;
74     bigram_t *bg;
75
76     if (lw1 < 0 || model->base.n < 2) {
77         *n_used = 1;
78         return model->lm3g.unigrams[lw2].prob1.l;
79     }
80
81     b = FIRST_BG(model, lw1);
82     n = FIRST_BG(model, lw1 + 1) - b;
83     bg = model->lm3g.bigrams + b;
84
85     if ((i = find_bg(bg, n, lw2)) >= 0) {
86         /* Access mode = bigram */
87         *n_used = 2;
88         score = model->lm3g.prob2[bg[i].prob2].l;
89     }
90     else {
91         /* Access mode = unigram */
92         *n_used = 1;
93         score = model->lm3g.unigrams[lw1].bo_wt1.l + model->lm3g.unigrams[lw2].prob1.l;
94     }
95
96     return (score);
97 }
98
99 static void
100 load_tginfo(NGRAM_MODEL_TYPE *model, int32 lw1, int32 lw2)
101 {
102     int32 i, n, b, t;
103     bigram_t *bg;
104     tginfo_t *tginfo;
105
106     /* First allocate space for tg information for bg lw1,lw2 */
107     tginfo = (tginfo_t *) listelem_malloc(model->lm3g.le);
108     tginfo->w1 = lw1;
109     tginfo->tg = NULL;
110     tginfo->next = model->lm3g.tginfo[lw2];
111     model->lm3g.tginfo[lw2] = tginfo;
112
113     /* Locate bigram lw1,lw2 */
114     b = model->lm3g.unigrams[lw1].bigrams;
115     n = model->lm3g.unigrams[lw1 + 1].bigrams - b;
116     bg = model->lm3g.bigrams + b;
117
118     if ((n > 0) && ((i = find_bg(bg, n, lw2)) >= 0)) {
119         tginfo->bowt = model->lm3g.bo_wt2[bg[i].bo_wt2].l;
120
121         /* Find t = Absolute first trigram index for bigram lw1,lw2 */
122         b += i;                 /* b = Absolute index of bigram lw1,lw2 on disk */
123         t = FIRST_TG(model, b);
124
125         tginfo->tg = model->lm3g.trigrams + t;
126
127         /* Find #tg for bigram w1,w2 */
128         tginfo->n_tg = FIRST_TG(model, b + 1) - t;
129     }
130     else {                      /* No bigram w1,w2 */
131         tginfo->bowt = 0;
132         tginfo->n_tg = 0;
133     }
134 }
135
136 /* Similar to find_bg */
137 static int32
138 find_tg(trigram_t * tg, int32 n, int32 w)
139 {
140     int32 i, b, e;
141
142     b = 0;
143     e = n;
144     while (e - b > BINARY_SEARCH_THRESH) {
145         i = (b + e) >> 1;
146         if (tg[i].wid < w)
147             b = i + 1;
148         else if (tg[i].wid > w)
149             e = i;
150         else
151             return i;
152     }
153
154     for (i = b; (i < e) && (tg[i].wid != w); i++);
155     return ((i < e) ? i : -1);
156 }
157
158 static int32
159 lm3g_tg_score(NGRAM_MODEL_TYPE *model, int32 lw1,
160               int32 lw2, int32 lw3, int32 *n_used)
161 {
162     ngram_model_t *base = &model->base;
163     int32 i, n, score;
164     trigram_t *tg;
165     tginfo_t *tginfo, *prev_tginfo;
166
167     if ((base->n < 3) || (lw1 < 0) || (lw2 < 0))
168         return (lm3g_bg_score(model, lw2, lw3, n_used));
169
170     prev_tginfo = NULL;
171     for (tginfo = model->lm3g.tginfo[lw2]; tginfo; tginfo = tginfo->next) {
172         if (tginfo->w1 == lw1)
173             break;
174         prev_tginfo = tginfo;
175     }
176
177     if (!tginfo) {
178         load_tginfo(model, lw1, lw2);
179         tginfo = model->lm3g.tginfo[lw2];
180     }
181     else if (prev_tginfo) {
182         prev_tginfo->next = tginfo->next;
183         tginfo->next = model->lm3g.tginfo[lw2];
184         model->lm3g.tginfo[lw2] = tginfo;
185     }
186
187     tginfo->used = 1;
188
189     /* Trigrams for w1,w2 now pointed to by tginfo */
190     n = tginfo->n_tg;
191     tg = tginfo->tg;
192     if ((i = find_tg(tg, n, lw3)) >= 0) {
193         /* Access mode = trigram */
194         *n_used = 3;
195         score = model->lm3g.prob3[tg[i].prob3].l;
196     }
197     else {
198         score = tginfo->bowt + lm3g_bg_score(model, lw2, lw3, n_used);
199     }
200
201     return (score);
202 }
203
204 static int32
205 lm3g_template_score(ngram_model_t *base, int32 wid,
206                       int32 *history, int32 n_hist,
207                       int32 *n_used)
208 {
209     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
210     switch (n_hist) {
211     case 0:
212         /* Access mode: unigram */
213         *n_used = 1;
214         return model->lm3g.unigrams[wid].prob1.l;
215     case 1:
216         return lm3g_bg_score(model, history[0], wid, n_used);
217     case 2:
218     default:
219         /* Anything greater than 2 is the same as a trigram for now. */
220         return lm3g_tg_score(model, history[1], history[0], wid, n_used);
221     }
222 }
223
224 static int32
225 lm3g_template_raw_score(ngram_model_t *base, int32 wid,
226                         int32 *history, int32 n_hist,
227                           int32 *n_used)
228 {
229     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
230     int32 score;
231
232     switch (n_hist) {
233     case 0:
234         /* Access mode: unigram */
235         *n_used = 1;
236         /* Undo insertion penalty. */
237         score = model->lm3g.unigrams[wid].prob1.l - base->log_wip;
238         /* Undo language weight. */
239         score = (int32)(score / base->lw);
240         /* Undo unigram interpolation */
241         if (strcmp(base->word_str[wid], "<s>") != 0) { /* FIXME: configurable start_sym */
242             score = logmath_log(base->lmath,
243                                 logmath_exp(base->lmath, score)
244                                 - logmath_exp(base->lmath, 
245                                               base->log_uniform + base->log_uniform_weight));
246         }
247         return score;
248     case 1:
249         score = lm3g_bg_score(model, history[0], wid, n_used);
250         break;
251     case 2:
252     default:
253         /* Anything greater than 2 is the same as a trigram for now. */
254         score = lm3g_tg_score(model, history[1], history[0], wid, n_used);
255         break;
256     }
257     /* FIXME (maybe): This doesn't undo unigram weighting in backoff cases. */
258     return (int32)((score - base->log_wip) / base->lw);
259 }
260
261 static int32
262 lm3g_template_add_ug(ngram_model_t *base,
263                        int32 wid, int32 lweight)
264 {
265     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
266     return lm3g_add_ug(base, &model->lm3g, wid, lweight);
267 }
268
269 static void
270 lm3g_template_flush(ngram_model_t *base)
271 {
272     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
273     lm3g_tginfo_reset(base, &model->lm3g);
274 }
275
276 typedef struct lm3g_iter_s {
277     ngram_iter_t base;
278     unigram_t *ug;
279     bigram_t *bg;
280     trigram_t *tg;
281 } lm3g_iter_t;
282
283 static ngram_iter_t *
284 lm3g_template_iter(ngram_model_t *base, int32 wid,
285                    int32 *history, int32 n_hist)
286 {
287     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
288     lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
289
290     ngram_iter_init((ngram_iter_t *)itor, base, n_hist, FALSE);
291
292     if (n_hist == 0) {
293         /* Unigram is the easiest. */
294         itor->ug = model->lm3g.unigrams + wid;
295         return (ngram_iter_t *)itor;
296     }
297     else if (n_hist == 1) {
298         int32 i, n, b;
299         /* Find the bigram, as in bg_score above (duplicate code...) */
300         itor->ug = model->lm3g.unigrams + history[0];
301         b = FIRST_BG(model, history[0]);
302         n = FIRST_BG(model, history[0] + 1) - b;
303         itor->bg = model->lm3g.bigrams + b;
304         /* If no such bigram exists then fail. */
305         if ((i = find_bg(itor->bg, n, wid)) < 0) {
306             ngram_iter_free((ngram_iter_t *)itor);
307             return NULL;
308         }
309         itor->bg += i;
310         return (ngram_iter_t *)itor;
311     }
312     else if (n_hist == 2) {
313         int32 i, n;
314         tginfo_t *tginfo, *prev_tginfo;
315         /* Find the trigram, as in tg_score above (duplicate code...) */
316         itor->ug = model->lm3g.unigrams + history[1];
317         prev_tginfo = NULL;
318         for (tginfo = model->lm3g.tginfo[history[0]];
319              tginfo; tginfo = tginfo->next) {
320             if (tginfo->w1 == history[1])
321                 break;
322             prev_tginfo = tginfo;
323         }
324
325         if (!tginfo) {
326             load_tginfo(model, history[1], history[0]);
327             tginfo = model->lm3g.tginfo[history[0]];
328         }
329         else if (prev_tginfo) {
330             prev_tginfo->next = tginfo->next;
331             tginfo->next = model->lm3g.tginfo[history[0]];
332             model->lm3g.tginfo[history[0]] = tginfo;
333         }
334
335         tginfo->used = 1;
336
337         /* Trigrams for w1,w2 now pointed to by tginfo */
338         n = tginfo->n_tg;
339         itor->tg = tginfo->tg;
340         if ((i = find_tg(itor->tg, n, wid)) >= 0) {
341             itor->tg += i;
342             /* Now advance the bigram pointer accordingly.  FIXME:
343              * Note that we actually already found the relevant bigram
344              * in load_tginfo. */
345             itor->bg = model->lm3g.bigrams;
346             while (FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))
347                    <= (itor->tg - model->lm3g.trigrams))
348                 ++itor->bg;
349             return (ngram_iter_t *)itor;
350         }
351         else {
352             ngram_iter_free((ngram_iter_t *)itor);
353             return (ngram_iter_t *)NULL;
354         }
355     }
356     else {
357         /* Should not happen. */
358         assert(n_hist == 0); /* Guaranteed to fail. */
359         ngram_iter_free((ngram_iter_t *)itor);
360         return NULL;
361     }
362 }
363
364 static ngram_iter_t *
365 lm3g_template_mgrams(ngram_model_t *base, int m)
366 {
367     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
368     lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
369     ngram_iter_init((ngram_iter_t *)itor, base, m, FALSE);
370
371     itor->ug = model->lm3g.unigrams;
372     itor->bg = model->lm3g.bigrams;
373     itor->tg = model->lm3g.trigrams;
374
375     /* Advance bigram pointer to match first trigram. */
376     if (m > 1 && base->n_counts[1] > 1)  {
377         while (FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))
378                <= (itor->tg - model->lm3g.trigrams))
379             ++itor->bg;
380     }
381
382     /* Advance unigram pointer to match first bigram. */
383     if (m > 0 && base->n_counts[0] > 1) {
384         while (itor->ug[1].bigrams <= (itor->bg - model->lm3g.bigrams))
385             ++itor->ug;
386     }
387
388     return (ngram_iter_t *)itor;
389 }
390
391 static ngram_iter_t *
392 lm3g_template_successors(ngram_iter_t *bitor)
393 {
394     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)bitor->model;
395     lm3g_iter_t *from = (lm3g_iter_t *)bitor;
396     lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
397
398     itor->ug = from->ug;
399     switch (bitor->m) {
400     case 0:
401         /* Next itor bigrams is the same as this itor bigram or
402            itor bigrams is more than total count. This means no successors */
403         if (((itor->ug + 1) - model->lm3g.unigrams < bitor->model->n_counts[0] &&
404             itor->ug->bigrams == (itor->ug + 1)->bigrams) || 
405             itor->ug->bigrams == bitor->model->n_counts[1])
406             goto done;
407             
408         /* Start iterating from first bigram successor of from->ug. */
409         itor->bg = model->lm3g.bigrams + itor->ug->bigrams;
410         break;
411     case 1:
412         itor->bg = from->bg;
413         
414         /* This indicates no successors */
415         if (((itor->bg + 1) - model->lm3g.bigrams < bitor->model->n_counts[1] &&
416             FIRST_TG (model, itor->bg - model->lm3g.bigrams) == 
417             FIRST_TG (model, (itor->bg + 1) - model->lm3g.bigrams)) ||
418             FIRST_TG (model, itor->bg - model->lm3g.bigrams) == bitor->model->n_counts[2])
419             goto done;
420             
421         /* Start iterating from first trigram successor of from->bg. */
422         itor->tg = (model->lm3g.trigrams 
423                     + FIRST_TG(model, (itor->bg - model->lm3g.bigrams)));
424 #if 0
425         printf("%s %s => %d (%s)\n",
426                model->base.word_str[itor->ug - model->lm3g.unigrams],
427                model->base.word_str[itor->bg->wid],
428                FIRST_TG(model, (itor->bg - model->lm3g.bigrams)),
429                model->base.word_str[itor->tg->wid]);
430 #endif
431         break;
432     case 2:
433     default:
434         /* All invalid! */
435         goto done;
436     }
437
438     ngram_iter_init((ngram_iter_t *)itor, bitor->model, bitor->m + 1, TRUE);
439     return (ngram_iter_t *)itor;
440     done:
441         ckd_free(itor);
442         return NULL;
443 }
444
445 static int32 const *
446 lm3g_template_iter_get(ngram_iter_t *base,
447                        int32 *out_score, int32 *out_bowt)
448 {
449     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base->model;
450     lm3g_iter_t *itor = (lm3g_iter_t *)base;
451
452     base->wids[0] = itor->ug - model->lm3g.unigrams;
453     if (itor->bg) base->wids[1] = itor->bg->wid;
454     if (itor->tg) base->wids[2] = itor->tg->wid;
455 #if 0
456     printf("itor_get: %d %d %d\n", base->wids[0], base->wids[1], base->wids[2]);
457 #endif
458
459     switch (base->m) {
460     case 0:
461         *out_score = itor->ug->prob1.l;
462         *out_bowt = itor->ug->bo_wt1.l;
463         break;
464     case 1:
465         *out_score = model->lm3g.prob2[itor->bg->prob2].l;
466         if (model->lm3g.bo_wt2)
467             *out_bowt = model->lm3g.bo_wt2[itor->bg->bo_wt2].l;
468         else
469             *out_bowt = 0;
470         break;
471     case 2:
472         *out_score = model->lm3g.prob3[itor->tg->prob3].l;
473         *out_bowt = 0;
474         break;
475     default: /* Should not happen. */
476         return NULL;
477     }
478     return base->wids;
479 }
480
481 static ngram_iter_t *
482 lm3g_template_iter_next(ngram_iter_t *base)
483 {
484     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base->model;
485     lm3g_iter_t *itor = (lm3g_iter_t *)base;
486
487     switch (base->m) {
488     case 0:
489         ++itor->ug;
490         /* Check for end condition. */
491         if (itor->ug - model->lm3g.unigrams >= base->model->n_counts[0])
492             goto done;
493         break;
494     case 1:
495         ++itor->bg;
496         /* Check for end condition. */
497         if (itor->bg - model->lm3g.bigrams >= base->model->n_counts[1])
498             goto done;
499         /* Advance unigram pointer if necessary in order to get one
500          * that points to this bigram. */
501         while (itor->bg - model->lm3g.bigrams >= itor->ug[1].bigrams) {
502             /* Stop if this is a successor iterator, since we don't
503              * want a new unigram. */
504             if (base->successor)
505                 goto done;
506             ++itor->ug;
507             if (itor->ug == model->lm3g.unigrams + base->model->n_counts[0]) {
508                 E_ERROR("Bigram %d has no valid unigram parent\n",
509                         itor->bg - model->lm3g.bigrams);
510                 goto done;
511             }
512         }
513         break;
514     case 2:
515         ++itor->tg;
516         /* Check for end condition. */
517         if (itor->tg - model->lm3g.trigrams >= base->model->n_counts[2])
518             goto done;
519         /* Advance bigram pointer if necessary. */
520         while (itor->tg - model->lm3g.trigrams >=
521             FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))) {
522             if (base->successor)
523                 goto done;
524             ++itor->bg;
525             if (itor->bg == model->lm3g.bigrams + base->model->n_counts[1]) {
526                 E_ERROR("Trigram %d has no valid bigram parent\n",
527                         itor->tg - model->lm3g.trigrams);
528
529                goto done;
530             }
531         }
532         /* Advance unigram pointer if necessary. */
533         while (itor->bg - model->lm3g.bigrams >= itor->ug[1].bigrams) {
534             ++itor->ug;
535             if (itor->ug == model->lm3g.unigrams + base->model->n_counts[0]) {
536                 E_ERROR("Trigram %d has no valid unigram parent\n",
537                         itor->tg - model->lm3g.trigrams);
538                 goto done;
539             }
540         }
541         break;
542     default: /* Should not happen. */
543         goto done;
544     }
545
546     return (ngram_iter_t *)itor;
547 done:
548     ngram_iter_free(base);
549     return NULL;
550 }
551
552 static void
553 lm3g_template_iter_free(ngram_iter_t *base)
554 {
555     ckd_free(base);
556 }