1 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /* ====================================================================
3 * Copyright (c) 1999-2007 Carnegie Mellon University. All rights
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
10 * 1. Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 * notice, this list of conditions and the following disclaimer in
15 * the documentation and/or other materials provided with the
18 * This work was supported in part by funding from the Defense Advanced
19 * Research Projects Agency and the National Science Foundation of the
20 * United States of America, and the CMU Sphinx Speech Consortium.
22 * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
23 * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
24 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26 * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 * ====================================================================
38 * \file lm3g_templates.c Core Sphinx 3-gram code used in
39 * DMP/DMP32/ARPA (for now) model code.
44 /* Locate a specific bigram within a bigram list */
45 #define BINARY_SEARCH_THRESH 16
47 find_bg(bigram_t * bg, int32 n, int32 w)
51 /* Binary search until segment size < threshold */
54 while (e - b > BINARY_SEARCH_THRESH) {
58 else if (bg[i].wid > w)
64 /* Linear search within narrowed segment */
65 for (i = b; (i < e) && (bg[i].wid != w); i++);
66 return ((i < e) ? i : -1);
70 lm3g_bg_score(NGRAM_MODEL_TYPE *model,
71 int32 lw1, int32 lw2, int32 *n_used)
76 if (lw1 < 0 || model->base.n < 2) {
78 return model->lm3g.unigrams[lw2].prob1.l;
81 b = FIRST_BG(model, lw1);
82 n = FIRST_BG(model, lw1 + 1) - b;
83 bg = model->lm3g.bigrams + b;
85 if ((i = find_bg(bg, n, lw2)) >= 0) {
86 /* Access mode = bigram */
88 score = model->lm3g.prob2[bg[i].prob2].l;
91 /* Access mode = unigram */
93 score = model->lm3g.unigrams[lw1].bo_wt1.l + model->lm3g.unigrams[lw2].prob1.l;
100 load_tginfo(NGRAM_MODEL_TYPE *model, int32 lw1, int32 lw2)
106 /* First allocate space for tg information for bg lw1,lw2 */
107 tginfo = (tginfo_t *) listelem_malloc(model->lm3g.le);
110 tginfo->next = model->lm3g.tginfo[lw2];
111 model->lm3g.tginfo[lw2] = tginfo;
113 /* Locate bigram lw1,lw2 */
114 b = model->lm3g.unigrams[lw1].bigrams;
115 n = model->lm3g.unigrams[lw1 + 1].bigrams - b;
116 bg = model->lm3g.bigrams + b;
118 if ((n > 0) && ((i = find_bg(bg, n, lw2)) >= 0)) {
119 tginfo->bowt = model->lm3g.bo_wt2[bg[i].bo_wt2].l;
121 /* Find t = Absolute first trigram index for bigram lw1,lw2 */
122 b += i; /* b = Absolute index of bigram lw1,lw2 on disk */
123 t = FIRST_TG(model, b);
125 tginfo->tg = model->lm3g.trigrams + t;
127 /* Find #tg for bigram w1,w2 */
128 tginfo->n_tg = FIRST_TG(model, b + 1) - t;
130 else { /* No bigram w1,w2 */
136 /* Similar to find_bg */
138 find_tg(trigram_t * tg, int32 n, int32 w)
144 while (e - b > BINARY_SEARCH_THRESH) {
148 else if (tg[i].wid > w)
154 for (i = b; (i < e) && (tg[i].wid != w); i++);
155 return ((i < e) ? i : -1);
159 lm3g_tg_score(NGRAM_MODEL_TYPE *model, int32 lw1,
160 int32 lw2, int32 lw3, int32 *n_used)
162 ngram_model_t *base = &model->base;
165 tginfo_t *tginfo, *prev_tginfo;
167 if ((base->n < 3) || (lw1 < 0) || (lw2 < 0))
168 return (lm3g_bg_score(model, lw2, lw3, n_used));
171 for (tginfo = model->lm3g.tginfo[lw2]; tginfo; tginfo = tginfo->next) {
172 if (tginfo->w1 == lw1)
174 prev_tginfo = tginfo;
178 load_tginfo(model, lw1, lw2);
179 tginfo = model->lm3g.tginfo[lw2];
181 else if (prev_tginfo) {
182 prev_tginfo->next = tginfo->next;
183 tginfo->next = model->lm3g.tginfo[lw2];
184 model->lm3g.tginfo[lw2] = tginfo;
189 /* Trigrams for w1,w2 now pointed to by tginfo */
192 if ((i = find_tg(tg, n, lw3)) >= 0) {
193 /* Access mode = trigram */
195 score = model->lm3g.prob3[tg[i].prob3].l;
198 score = tginfo->bowt + lm3g_bg_score(model, lw2, lw3, n_used);
205 lm3g_template_score(ngram_model_t *base, int32 wid,
206 int32 *history, int32 n_hist,
209 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
212 /* Access mode: unigram */
214 return model->lm3g.unigrams[wid].prob1.l;
216 return lm3g_bg_score(model, history[0], wid, n_used);
219 /* Anything greater than 2 is the same as a trigram for now. */
220 return lm3g_tg_score(model, history[1], history[0], wid, n_used);
225 lm3g_template_raw_score(ngram_model_t *base, int32 wid,
226 int32 *history, int32 n_hist,
229 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
234 /* Access mode: unigram */
236 /* Undo insertion penalty. */
237 score = model->lm3g.unigrams[wid].prob1.l - base->log_wip;
238 /* Undo language weight. */
239 score = (int32)(score / base->lw);
240 /* Undo unigram interpolation */
241 if (strcmp(base->word_str[wid], "<s>") != 0) { /* FIXME: configurable start_sym */
242 score = logmath_log(base->lmath,
243 logmath_exp(base->lmath, score)
244 - logmath_exp(base->lmath,
245 base->log_uniform + base->log_uniform_weight));
249 score = lm3g_bg_score(model, history[0], wid, n_used);
253 /* Anything greater than 2 is the same as a trigram for now. */
254 score = lm3g_tg_score(model, history[1], history[0], wid, n_used);
257 /* FIXME (maybe): This doesn't undo unigram weighting in backoff cases. */
258 return (int32)((score - base->log_wip) / base->lw);
262 lm3g_template_add_ug(ngram_model_t *base,
263 int32 wid, int32 lweight)
265 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
266 return lm3g_add_ug(base, &model->lm3g, wid, lweight);
270 lm3g_template_flush(ngram_model_t *base)
272 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
273 lm3g_tginfo_reset(base, &model->lm3g);
276 typedef struct lm3g_iter_s {
283 static ngram_iter_t *
284 lm3g_template_iter(ngram_model_t *base, int32 wid,
285 int32 *history, int32 n_hist)
287 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
288 lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
290 ngram_iter_init((ngram_iter_t *)itor, base, n_hist, FALSE);
293 /* Unigram is the easiest. */
294 itor->ug = model->lm3g.unigrams + wid;
295 return (ngram_iter_t *)itor;
297 else if (n_hist == 1) {
299 /* Find the bigram, as in bg_score above (duplicate code...) */
300 itor->ug = model->lm3g.unigrams + history[0];
301 b = FIRST_BG(model, history[0]);
302 n = FIRST_BG(model, history[0] + 1) - b;
303 itor->bg = model->lm3g.bigrams + b;
304 /* If no such bigram exists then fail. */
305 if ((i = find_bg(itor->bg, n, wid)) < 0) {
306 ngram_iter_free((ngram_iter_t *)itor);
310 return (ngram_iter_t *)itor;
312 else if (n_hist == 2) {
314 tginfo_t *tginfo, *prev_tginfo;
315 /* Find the trigram, as in tg_score above (duplicate code...) */
316 itor->ug = model->lm3g.unigrams + history[1];
318 for (tginfo = model->lm3g.tginfo[history[0]];
319 tginfo; tginfo = tginfo->next) {
320 if (tginfo->w1 == history[1])
322 prev_tginfo = tginfo;
326 load_tginfo(model, history[1], history[0]);
327 tginfo = model->lm3g.tginfo[history[0]];
329 else if (prev_tginfo) {
330 prev_tginfo->next = tginfo->next;
331 tginfo->next = model->lm3g.tginfo[history[0]];
332 model->lm3g.tginfo[history[0]] = tginfo;
337 /* Trigrams for w1,w2 now pointed to by tginfo */
339 itor->tg = tginfo->tg;
340 if ((i = find_tg(itor->tg, n, wid)) >= 0) {
342 /* Now advance the bigram pointer accordingly. FIXME:
343 * Note that we actually already found the relevant bigram
345 itor->bg = model->lm3g.bigrams;
346 while (FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))
347 <= (itor->tg - model->lm3g.trigrams))
349 return (ngram_iter_t *)itor;
352 ngram_iter_free((ngram_iter_t *)itor);
353 return (ngram_iter_t *)NULL;
357 /* Should not happen. */
358 assert(n_hist == 0); /* Guaranteed to fail. */
359 ngram_iter_free((ngram_iter_t *)itor);
364 static ngram_iter_t *
365 lm3g_template_mgrams(ngram_model_t *base, int m)
367 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
368 lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
369 ngram_iter_init((ngram_iter_t *)itor, base, m, FALSE);
371 itor->ug = model->lm3g.unigrams;
372 itor->bg = model->lm3g.bigrams;
373 itor->tg = model->lm3g.trigrams;
375 /* Advance bigram pointer to match first trigram. */
376 if (m > 1 && base->n_counts[1] > 1) {
377 while (FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))
378 <= (itor->tg - model->lm3g.trigrams))
382 /* Advance unigram pointer to match first bigram. */
383 if (m > 0 && base->n_counts[0] > 1) {
384 while (itor->ug[1].bigrams <= (itor->bg - model->lm3g.bigrams))
388 return (ngram_iter_t *)itor;
391 static ngram_iter_t *
392 lm3g_template_successors(ngram_iter_t *bitor)
394 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)bitor->model;
395 lm3g_iter_t *from = (lm3g_iter_t *)bitor;
396 lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
401 /* Next itor bigrams is the same as this itor bigram or
402 itor bigrams is more than total count. This means no successors */
403 if (((itor->ug + 1) - model->lm3g.unigrams < bitor->model->n_counts[0] &&
404 itor->ug->bigrams == (itor->ug + 1)->bigrams) ||
405 itor->ug->bigrams == bitor->model->n_counts[1])
408 /* Start iterating from first bigram successor of from->ug. */
409 itor->bg = model->lm3g.bigrams + itor->ug->bigrams;
414 /* This indicates no successors */
415 if (((itor->bg + 1) - model->lm3g.bigrams < bitor->model->n_counts[1] &&
416 FIRST_TG (model, itor->bg - model->lm3g.bigrams) ==
417 FIRST_TG (model, (itor->bg + 1) - model->lm3g.bigrams)) ||
418 FIRST_TG (model, itor->bg - model->lm3g.bigrams) == bitor->model->n_counts[2])
421 /* Start iterating from first trigram successor of from->bg. */
422 itor->tg = (model->lm3g.trigrams
423 + FIRST_TG(model, (itor->bg - model->lm3g.bigrams)));
425 printf("%s %s => %d (%s)\n",
426 model->base.word_str[itor->ug - model->lm3g.unigrams],
427 model->base.word_str[itor->bg->wid],
428 FIRST_TG(model, (itor->bg - model->lm3g.bigrams)),
429 model->base.word_str[itor->tg->wid]);
438 ngram_iter_init((ngram_iter_t *)itor, bitor->model, bitor->m + 1, TRUE);
439 return (ngram_iter_t *)itor;
446 lm3g_template_iter_get(ngram_iter_t *base,
447 int32 *out_score, int32 *out_bowt)
449 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base->model;
450 lm3g_iter_t *itor = (lm3g_iter_t *)base;
452 base->wids[0] = itor->ug - model->lm3g.unigrams;
453 if (itor->bg) base->wids[1] = itor->bg->wid;
454 if (itor->tg) base->wids[2] = itor->tg->wid;
456 printf("itor_get: %d %d %d\n", base->wids[0], base->wids[1], base->wids[2]);
461 *out_score = itor->ug->prob1.l;
462 *out_bowt = itor->ug->bo_wt1.l;
465 *out_score = model->lm3g.prob2[itor->bg->prob2].l;
466 if (model->lm3g.bo_wt2)
467 *out_bowt = model->lm3g.bo_wt2[itor->bg->bo_wt2].l;
472 *out_score = model->lm3g.prob3[itor->tg->prob3].l;
475 default: /* Should not happen. */
481 static ngram_iter_t *
482 lm3g_template_iter_next(ngram_iter_t *base)
484 NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base->model;
485 lm3g_iter_t *itor = (lm3g_iter_t *)base;
490 /* Check for end condition. */
491 if (itor->ug - model->lm3g.unigrams >= base->model->n_counts[0])
496 /* Check for end condition. */
497 if (itor->bg - model->lm3g.bigrams >= base->model->n_counts[1])
499 /* Advance unigram pointer if necessary in order to get one
500 * that points to this bigram. */
501 while (itor->bg - model->lm3g.bigrams >= itor->ug[1].bigrams) {
502 /* Stop if this is a successor iterator, since we don't
503 * want a new unigram. */
507 if (itor->ug == model->lm3g.unigrams + base->model->n_counts[0]) {
508 E_ERROR("Bigram %d has no valid unigram parent\n",
509 itor->bg - model->lm3g.bigrams);
516 /* Check for end condition. */
517 if (itor->tg - model->lm3g.trigrams >= base->model->n_counts[2])
519 /* Advance bigram pointer if necessary. */
520 while (itor->tg - model->lm3g.trigrams >=
521 FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))) {
525 if (itor->bg == model->lm3g.bigrams + base->model->n_counts[1]) {
526 E_ERROR("Trigram %d has no valid bigram parent\n",
527 itor->tg - model->lm3g.trigrams);
532 /* Advance unigram pointer if necessary. */
533 while (itor->bg - model->lm3g.bigrams >= itor->ug[1].bigrams) {
535 if (itor->ug == model->lm3g.unigrams + base->model->n_counts[0]) {
536 E_ERROR("Trigram %d has no valid unigram parent\n",
537 itor->tg - model->lm3g.trigrams);
542 default: /* Should not happen. */
546 return (ngram_iter_t *)itor;
548 ngram_iter_free(base);
553 lm3g_template_iter_free(ngram_iter_t *base)