Initial import to Tizen
[profile/ivi/sphinxbase.git] / src / libsphinxbase / lm / ngram_model_set.c
1 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /* ====================================================================
3  * Copyright (c) 2008 Carnegie Mellon University.  All rights
4  * reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer. 
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  *
18  * This work was supported in part by funding from the Defense Advanced 
19  * Research Projects Agency and the National Science Foundation of the 
20  * United States of America, and the CMU Sphinx Speech Consortium.
21  *
22  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
23  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
24  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
28  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
29  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
30  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
31  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
32  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  *
34  * ====================================================================
35  *
36  */
37 /**
38  * @file ngram_model_set.c Set of language models.
39  * @author David Huggins-Daines <dhuggins@cs.cmu.edu>
40  */
41
42 #include <string.h>
43 #include <stdlib.h>
44
45 #include "sphinxbase/err.h"
46 #include "sphinxbase/ckd_alloc.h"
47 #include "sphinxbase/strfuncs.h"
48 #include "sphinxbase/filename.h"
49
50 #include "ngram_model_set.h"
51
52 static ngram_funcs_t ngram_model_set_funcs;
53
54 static int
55 my_compare(const void *a, const void *b)
56 {
57     /* Make sure <UNK> floats to the beginning. */
58     if (strcmp(*(char * const *)a, "<UNK>") == 0)
59         return -1;
60     else if (strcmp(*(char * const *)b, "<UNK>") == 0)
61         return 1;
62     else
63         return strcmp(*(char * const *)a, *(char * const *)b);
64 }
65
66 static void
67 build_widmap(ngram_model_t *base, logmath_t *lmath, int32 n)
68 {
69     ngram_model_set_t *set = (ngram_model_set_t *)base;
70     ngram_model_t **models = set->lms;
71     hash_table_t *vocab;
72     glist_t hlist;
73     gnode_t *gn;
74     int32 i;
75
76     /* Construct a merged vocabulary and a set of word-ID mappings. */
77     vocab = hash_table_new(models[0]->n_words, FALSE);
78     /* Create the set of merged words. */
79     for (i = 0; i < set->n_models; ++i) {
80         int32 j;
81         for (j = 0; j < models[i]->n_words; ++j) {
82             /* Ignore collisions. */
83             (void)hash_table_enter_int32(vocab, models[i]->word_str[j], j);
84         }
85     }
86     /* Create the array of words, then sort it. */
87     if (hash_table_lookup(vocab, "<UNK>", NULL) != 0)
88         (void)hash_table_enter_int32(vocab, "<UNK>", 0);
89     /* Now we know the number of unigrams, initialize the base model. */
90     ngram_model_init(base, &ngram_model_set_funcs, lmath, n, hash_table_inuse(vocab));
91     base->writable = FALSE; /* We will reuse the pointers from the submodels. */
92     i = 0;
93     hlist = hash_table_tolist(vocab, NULL);
94     for (gn = hlist; gn; gn = gnode_next(gn)) {
95         hash_entry_t *ent = gnode_ptr(gn);
96         base->word_str[i++] = (char *)ent->key;
97     }
98     glist_free(hlist);
99     qsort(base->word_str, base->n_words, sizeof(*base->word_str), my_compare);
100
101     /* Now create the word ID mappings. */
102     if (set->widmap)
103         ckd_free_2d((void **)set->widmap);
104     set->widmap = (int32 **) ckd_calloc_2d(base->n_words, set->n_models,
105                                            sizeof(**set->widmap));
106     for (i = 0; i < base->n_words; ++i) {
107         int32 j;
108         /* Also create the master wid mapping. */
109         (void)hash_table_enter_int32(base->wid, base->word_str[i], i);
110         /* printf("%s: %d => ", base->word_str[i], i); */
111         for (j = 0; j < set->n_models; ++j) {
112             set->widmap[i][j] = ngram_wid(models[j], base->word_str[i]);
113             /* printf("%d ", set->widmap[i][j]); */
114         }
115         /* printf("\n"); */
116     }
117     hash_table_free(vocab);
118 }
119
120 ngram_model_t *
121 ngram_model_set_init(cmd_ln_t *config,
122                      ngram_model_t **models,
123                      char **names,
124                      const float32 *weights,
125                      int32 n_models)
126 {
127     ngram_model_set_t *model;
128     ngram_model_t *base;
129     logmath_t *lmath;
130     int32 i, n;
131
132     if (n_models == 0) /* WTF */
133         return NULL;
134
135     /* Do consistency checking on the models.  They must all use the
136      * same logbase and shift. */
137     lmath = models[0]->lmath;
138     for (i = 1; i < n_models; ++i) {
139         if (logmath_get_base(models[i]->lmath) != logmath_get_base(lmath)
140             || logmath_get_shift(models[i]->lmath) != logmath_get_shift(lmath)) {
141             E_ERROR("Log-math parameters don't match, will not create LM set\n");
142             return NULL;
143         }
144     }
145
146     /* Allocate the combined model, initialize it. */
147     model = ckd_calloc(1, sizeof(*model));
148     base = &model->base;
149     model->n_models = n_models;
150     model->lms = ckd_calloc(n_models, sizeof(*model->lms));
151     model->names = ckd_calloc(n_models, sizeof(*model->names));
152     /* Initialize weights to a uniform distribution */
153     model->lweights = ckd_calloc(n_models, sizeof(*model->lweights));
154     {
155         int32 uniform = logmath_log(lmath, 1.0/n_models);
156         for (i = 0; i < n_models; ++i)
157             model->lweights[i] = uniform;
158     }
159     /* Default to interpolate if weights were given. */
160     if (weights)
161         model->cur = -1;
162
163     n = 0;
164     for (i = 0; i < n_models; ++i) {
165         model->lms[i] = models[i];
166         model->names[i] = ckd_salloc(names[i]);
167         if (weights)
168             model->lweights[i] = logmath_log(lmath, weights[i]);
169         /* N is the maximum of all merged models. */
170         if (models[i]->n > n)
171             n = models[i]->n;
172     }
173     /* Allocate the history mapping table. */
174     model->maphist = ckd_calloc(n - 1, sizeof(*model->maphist));
175
176     /* Now build the word-ID mapping and merged vocabulary. */
177     build_widmap(base, lmath, n);
178     return base;
179 }
180
181 ngram_model_t *
182 ngram_model_set_read(cmd_ln_t *config,
183                      const char *lmctlfile,
184                      logmath_t *lmath)
185 {
186     FILE *ctlfp;
187     glist_t lms = NULL;
188     glist_t lmnames = NULL;
189     __BIGSTACKVARIABLE__ char str[1024];
190     ngram_model_t *set = NULL;
191     hash_table_t *classes;
192     char *basedir, *c;
193
194     /* Read all the class definition files to accumulate a mapping of
195      * classnames to definitions. */
196     classes = hash_table_new(0, FALSE);
197     if ((ctlfp = fopen(lmctlfile, "r")) == NULL) {
198         E_ERROR_SYSTEM("Failed to open %s", lmctlfile);
199         return NULL;
200     }
201
202     /* Try to find the base directory to append to relative paths in
203      * the lmctl file. */
204     if ((c = strrchr(lmctlfile, '/')) || (c = strrchr(lmctlfile, '\\'))) {
205         /* Include the trailing slash. */
206         basedir = ckd_calloc(c - lmctlfile + 2, 1);
207         memcpy(basedir, lmctlfile, c - lmctlfile + 1);
208     }
209     else {
210         basedir = NULL;
211     }
212     E_INFO("Reading LM control file '%s'\n", lmctlfile);
213     if (basedir)
214         E_INFO("Will prepend '%s' to unqualified paths\n", basedir);
215
216     if (fscanf(ctlfp, "%1023s", str) == 1) {
217         if (strcmp(str, "{") == 0) {
218             /* Load LMclass files */
219             while ((fscanf(ctlfp, "%1023s", str) == 1)
220                    && (strcmp(str, "}") != 0)) {
221                 char *deffile;
222                 if (basedir && !path_is_absolute(str))
223                     deffile = string_join(basedir, str, NULL);
224                 else
225                     deffile = ckd_salloc(str);
226                 E_INFO("Reading classdef from '%s'\n", deffile);
227                 if (read_classdef_file(classes, deffile) < 0) {
228                     ckd_free(deffile);
229                     goto error_out;
230                 }
231                 ckd_free(deffile);
232             }
233
234             if (strcmp(str, "}") != 0) {
235                 E_ERROR("Unexpected EOF in %s\n", lmctlfile);
236                 goto error_out;
237             }
238
239             /* This might be the first LM name. */
240             if (fscanf(ctlfp, "%1023s", str) != 1)
241                 str[0] = '\0';
242         }
243     }
244     else
245         str[0] = '\0';
246
247     /* Read in one LM at a time and add classes to them as necessary. */
248     while (str[0] != '\0') {
249         char *lmfile;
250         ngram_model_t *lm;
251
252         if (basedir && str[0] != '/' && str[0] != '\\')
253             lmfile = string_join(basedir, str, NULL);
254         else
255             lmfile = ckd_salloc(str);
256         E_INFO("Reading lm from '%s'\n", lmfile);
257         lm = ngram_model_read(config, lmfile, NGRAM_AUTO, lmath);
258         if (lm == NULL) {
259             ckd_free(lmfile);
260             goto error_out;
261         }
262         if (fscanf(ctlfp, "%1023s", str) != 1) {
263             E_ERROR("LMname missing after LMFileName '%s'\n", lmfile);
264             ckd_free(lmfile);
265             goto error_out;
266         }
267         ckd_free(lmfile);
268         lms = glist_add_ptr(lms, lm);
269         lmnames = glist_add_ptr(lmnames, ckd_salloc(str));
270
271         if (fscanf(ctlfp, "%1023s", str) == 1) {
272             if (strcmp(str, "{") == 0) {
273                 /* LM uses classes; read their names */
274                 while ((fscanf(ctlfp, "%1023s", str) == 1) &&
275                        (strcmp(str, "}") != 0)) {
276                     void *val;
277                     classdef_t *classdef;
278
279                     if (hash_table_lookup(classes, str, &val) == -1) {
280                         E_ERROR("Unknown class %s in control file\n", str);
281                         goto error_out;
282                     }
283                     classdef = val;
284                     if (ngram_model_add_class(lm, str, 1.0,
285                                               classdef->words, classdef->weights,
286                                               classdef->n_words) < 0) {
287                         goto error_out;
288                     }
289                     E_INFO("Added class %s containing %d words\n",
290                            str, classdef->n_words);
291                 }
292                 if (strcmp(str, "}") != 0) {
293                     E_ERROR("Unexpected EOF in %s\n", lmctlfile);
294                     goto error_out;
295                 }
296                 if (fscanf(ctlfp, "%1023s", str) != 1)
297                     str[0] = '\0';
298             }
299         }
300         else
301             str[0] = '\0';
302     }
303     fclose(ctlfp);
304
305     /* Now construct arrays out of lms and lmnames, and build an
306      * ngram_model_set. */
307     lms = glist_reverse(lms);
308     lmnames = glist_reverse(lmnames);
309     {
310         int32 n_models;
311         ngram_model_t **lm_array;
312         char **name_array;
313         gnode_t *lm_node, *name_node;
314         int32 i;
315
316         n_models = glist_count(lms);
317         lm_array = ckd_calloc(n_models, sizeof(*lm_array));
318         name_array = ckd_calloc(n_models, sizeof(*name_array));
319         lm_node = lms;
320         name_node = lmnames;
321         for (i = 0; i < n_models; ++i) {
322             lm_array[i] = gnode_ptr(lm_node);
323             name_array[i] = gnode_ptr(name_node);
324             lm_node = gnode_next(lm_node);
325             name_node = gnode_next(name_node);
326         }
327         set = ngram_model_set_init(config, lm_array, name_array,
328                                    NULL, n_models);
329         ckd_free(lm_array);
330         ckd_free(name_array);
331     }
332 error_out:
333     {
334         gnode_t *gn;
335         glist_t hlist;
336
337         if (set == NULL) {
338             for (gn = lms; gn; gn = gnode_next(gn)) {
339                 ngram_model_free(gnode_ptr(gn));
340             }
341         }
342         glist_free(lms);
343         for (gn = lmnames; gn; gn = gnode_next(gn)) {
344             ckd_free(gnode_ptr(gn));
345         }
346         glist_free(lmnames);
347         hlist = hash_table_tolist(classes, NULL);
348         for (gn = hlist; gn; gn = gnode_next(gn)) {
349             hash_entry_t *he = gnode_ptr(gn);
350             ckd_free((char *)he->key);
351             classdef_free(he->val);
352         }
353         glist_free(hlist);
354         hash_table_free(classes);
355         ckd_free(basedir);
356     }
357     return set;
358 }
359
360 int32
361 ngram_model_set_count(ngram_model_t *base)
362 {
363     ngram_model_set_t *set = (ngram_model_set_t *)base;
364     return set->n_models;
365 }
366
367 ngram_model_set_iter_t *
368 ngram_model_set_iter(ngram_model_t *base)
369 {
370     ngram_model_set_t *set = (ngram_model_set_t *)base;
371     ngram_model_set_iter_t *itor;
372
373     if (set == NULL || set->n_models == 0)
374         return NULL;
375     itor = ckd_calloc(1, sizeof(*itor));
376     itor->set = set;
377     return itor;
378 }
379
380 ngram_model_set_iter_t *
381 ngram_model_set_iter_next(ngram_model_set_iter_t *itor)
382 {
383     if (++itor->cur == itor->set->n_models) {
384         ngram_model_set_iter_free(itor);
385         return NULL;
386     }
387     return itor;
388 }
389
390 void
391 ngram_model_set_iter_free(ngram_model_set_iter_t *itor)
392 {
393     ckd_free(itor);
394 }
395
396 ngram_model_t *
397 ngram_model_set_iter_model(ngram_model_set_iter_t *itor,
398                            char const **lmname)
399 {
400     if (lmname) *lmname = itor->set->names[itor->cur];
401     return itor->set->lms[itor->cur];
402 }
403
404 ngram_model_t *
405 ngram_model_set_lookup(ngram_model_t *base,
406                        const char *name)
407 {
408     ngram_model_set_t *set = (ngram_model_set_t *)base;
409     int32 i;
410
411     if (name == NULL) {
412         if (set->cur == -1)
413             return NULL;
414         else
415             return set->lms[set->cur];
416     }
417
418     /* There probably won't be very many submodels. */
419     for (i = 0; i < set->n_models; ++i)
420         if (0 == strcmp(set->names[i], name))
421             break;
422     if (i == set->n_models)
423         return NULL;
424     return set->lms[i];
425 }
426
427 ngram_model_t *
428 ngram_model_set_select(ngram_model_t *base,
429                        const char *name)
430 {
431     ngram_model_set_t *set = (ngram_model_set_t *)base;
432     int32 i;
433
434     /* There probably won't be very many submodels. */
435     for (i = 0; i < set->n_models; ++i)
436         if (0 == strcmp(set->names[i], name))
437             break;
438     if (i == set->n_models)
439         return NULL;
440     set->cur = i;
441     return set->lms[set->cur];
442 }
443
444 const char *
445 ngram_model_set_current(ngram_model_t *base)
446 {
447     ngram_model_set_t *set = (ngram_model_set_t *)base;
448
449     if (set->cur == -1)
450         return NULL;
451     else
452         return set->names[set->cur];
453 }
454
455 int32
456 ngram_model_set_current_wid(ngram_model_t *base,
457                             int32 set_wid)
458 {
459     ngram_model_set_t *set = (ngram_model_set_t *)base;
460
461     if (set->cur == -1 || set_wid >= base->n_words)
462         return NGRAM_INVALID_WID;
463     else
464         return set->widmap[set_wid][set->cur];
465 }
466
467 int32
468 ngram_model_set_known_wid(ngram_model_t *base,
469                           int32 set_wid)
470 {
471     ngram_model_set_t *set = (ngram_model_set_t *)base;
472
473     if (set_wid >= base->n_words)
474         return FALSE;
475     else if (set->cur == -1) {
476         int32 i;
477         for (i = 0; i < set->n_models; ++i) {
478             if (set->widmap[set_wid][i] != ngram_unknown_wid(set->lms[i]))
479                 return TRUE;
480         }
481         return FALSE;
482     }
483     else
484         return (set->widmap[set_wid][set->cur]
485                 != ngram_unknown_wid(set->lms[set->cur]));
486 }
487
488 ngram_model_t *
489 ngram_model_set_interp(ngram_model_t *base,
490                        const char **names,
491                        const float32 *weights)
492 {
493     ngram_model_set_t *set = (ngram_model_set_t *)base;
494
495     /* If we have a set of weights here, then set them. */
496     if (names && weights) {
497         int32 i, j;
498
499         /* We hope there aren't many models. */
500         for (i = 0; i < set->n_models; ++i) {
501             for (j = 0; j < set->n_models; ++j)
502                 if (0 == strcmp(names[i], set->names[j]))
503                     break;
504             if (j == set->n_models) {
505                 E_ERROR("Unknown LM name %s\n", names[i]);
506                 return NULL;
507             }
508             set->lweights[j] = logmath_log(base->lmath, weights[i]);
509         }
510     }
511     else if (weights) {
512         memcpy(set->lweights, weights, set->n_models * sizeof(*set->lweights));
513     }
514     /* Otherwise just enable existing weights. */
515     set->cur = -1;
516     return base;
517 }
518
519 ngram_model_t *
520 ngram_model_set_add(ngram_model_t *base,
521                     ngram_model_t *model,
522                     const char *name,
523                     float32 weight,
524                     int reuse_widmap)
525                     
526 {
527     ngram_model_set_t *set = (ngram_model_set_t *)base;
528     float32 fprob;
529     int32 scale, i;
530
531     /* Add it to the array of lms. */
532     ++set->n_models;
533     set->lms = ckd_realloc(set->lms, set->n_models * sizeof(*set->lms));
534     set->lms[set->n_models - 1] = model;
535     set->names = ckd_realloc(set->names, set->n_models * sizeof(*set->names));
536     set->names[set->n_models - 1] = ckd_salloc(name);
537     /* Expand the history mapping table if necessary. */
538     if (model->n > base->n) {
539         base->n = model->n;
540         set->maphist = ckd_realloc(set->maphist,
541                                    (model->n - 1) * sizeof(*set->maphist));
542     }
543
544     /* Renormalize the interpolation weights. */
545     fprob = weight * 1.0 / set->n_models;
546     set->lweights = ckd_realloc(set->lweights,
547                                 set->n_models * sizeof(*set->lweights));
548     set->lweights[set->n_models - 1] = logmath_log(base->lmath, fprob);
549     /* Now normalize everything else to fit it in.  This is
550      * accomplished by simply scaling all the other probabilities
551      * by (1-fprob). */
552     scale = logmath_log(base->lmath, 1.0 - fprob);
553     for (i = 0; i < set->n_models - 1; ++i)
554         set->lweights[i] += scale;
555
556     /* Reuse the old word ID mapping if requested. */
557     if (reuse_widmap) {
558         int32 **new_widmap;
559
560         /* Tack another column onto the widmap array. */
561         new_widmap = (int32 **)ckd_calloc_2d(base->n_words, set->n_models,
562                                              sizeof (**new_widmap));
563         for (i = 0; i < base->n_words; ++i) {
564             /* Copy all the existing mappings. */
565             memcpy(new_widmap[i], set->widmap[i],
566                    (set->n_models - 1) * sizeof(**new_widmap));
567             /* Create the new mapping. */
568             new_widmap[i][set->n_models-1] = ngram_wid(model, base->word_str[i]);
569         }
570         ckd_free_2d((void **)set->widmap);
571         set->widmap = new_widmap;
572     }
573     else {
574         build_widmap(base, base->lmath, base->n);
575     }
576     return model;
577 }
578
579 ngram_model_t *
580 ngram_model_set_remove(ngram_model_t *base,
581                        const char *name,
582                        int reuse_widmap)
583 {
584     ngram_model_set_t *set = (ngram_model_set_t *)base;
585     ngram_model_t *submodel;
586     int32 lmidx, scale, n, i;
587     float32 fprob;
588
589     for (lmidx = 0; lmidx < set->n_models; ++lmidx)
590         if (0 == strcmp(name, set->names[lmidx]))
591             break;
592     if (lmidx == set->n_models)
593         return NULL;
594     submodel = set->lms[lmidx];
595
596     /* Renormalize the interpolation weights by scaling them by
597      * 1/(1-fprob) */
598     fprob = logmath_exp(base->lmath, set->lweights[lmidx]);
599     scale = logmath_log(base->lmath, 1.0 - fprob);
600
601     /* Remove it from the array of lms, renormalize remaining weights,
602      * and recalcluate n. */
603     --set->n_models;
604     n = 0;
605     ckd_free(set->names[lmidx]);
606     set->names[lmidx] = NULL;
607     for (i = 0; i < set->n_models; ++i) {
608         if (i >= lmidx) {
609             set->lms[i] = set->lms[i+1];
610             set->names[i] = set->names[i+1];
611             set->lweights[i] = set->lweights[i+1];
612         }
613         set->lweights[i] -= scale;
614         if (set->lms[i]->n > n)
615             n = set->lms[i]->n;
616     }
617     /* There's no need to shrink these arrays. */
618     set->lms[set->n_models] = NULL;
619     set->lweights[set->n_models] = base->log_zero;
620     /* No need to shrink maphist either. */
621
622     /* Reuse the existing word ID mapping if requested. */
623     if (reuse_widmap) {
624         /* Just go through and shrink each row. */
625         for (i = 0; i < base->n_words; ++i) {
626             memmove(set->widmap[i] + lmidx, set->widmap[i] + lmidx + 1,
627                     (set->n_models - lmidx) * sizeof(**set->widmap));
628         }
629     }
630     else {
631         build_widmap(base, base->lmath, n);
632     }
633     return submodel;
634 }
635
636 void
637 ngram_model_set_map_words(ngram_model_t *base,
638                           const char **words,
639                           int32 n_words)
640 {
641     ngram_model_set_t *set = (ngram_model_set_t *)base;
642     int32 i;
643
644     /* Recreate the word mapping. */
645     if (base->writable) {
646         for (i = 0; i < base->n_words; ++i) {
647             ckd_free(base->word_str[i]);
648         }
649     }
650     ckd_free(base->word_str);
651     ckd_free_2d((void **)set->widmap);
652     base->writable = TRUE;
653     base->n_words = base->n_1g_alloc = n_words;
654     base->word_str = ckd_calloc(n_words, sizeof(*base->word_str));
655     set->widmap = (int32 **)ckd_calloc_2d(n_words, set->n_models, sizeof(**set->widmap));
656     hash_table_empty(base->wid);
657     for (i = 0; i < n_words; ++i) {
658         int32 j;
659         base->word_str[i] = ckd_salloc(words[i]);
660         (void)hash_table_enter_int32(base->wid, base->word_str[i], i);
661         for (j = 0; j < set->n_models; ++j) {
662             set->widmap[i][j] = ngram_wid(set->lms[j], base->word_str[i]);
663         }
664     }
665 }
666
667 static int
668 ngram_model_set_apply_weights(ngram_model_t *base, float32 lw,
669                               float32 wip, float32 uw)
670 {
671     ngram_model_set_t *set = (ngram_model_set_t *)base;
672     int32 i;
673
674     /* Apply weights to each sub-model. */
675     for (i = 0; i < set->n_models; ++i)
676         ngram_model_apply_weights(set->lms[i], lw, wip, uw);
677     return 0;
678 }
679
680 static int32
681 ngram_model_set_score(ngram_model_t *base, int32 wid,
682                       int32 *history, int32 n_hist,
683                       int32 *n_used)
684 {
685     ngram_model_set_t *set = (ngram_model_set_t *)base;
686     int32 mapwid;
687     int32 score;
688     int32 i;
689
690     /* Truncate the history. */
691     if (n_hist > base->n - 1)
692         n_hist = base->n - 1;
693
694     /* Interpolate if there is no current. */
695     if (set->cur == -1) {
696         score = base->log_zero;
697         for (i = 0; i < set->n_models; ++i) {
698             int32 j;
699             /* Map word and history IDs for each model. */
700             mapwid = set->widmap[wid][i];
701             for (j = 0; j < n_hist; ++j) {
702                 if (history[j] == NGRAM_INVALID_WID)
703                     set->maphist[j] = NGRAM_INVALID_WID;
704                 else
705                     set->maphist[j] = set->widmap[history[j]][i];
706             }
707             score = logmath_add(base->lmath, score,
708                                 set->lweights[i] + 
709                                 ngram_ng_score(set->lms[i],
710                                                mapwid, set->maphist, n_hist, n_used));
711         }
712     }
713     else {
714         int32 j;
715         /* Map word and history IDs (FIXME: do this in a function?) */
716         mapwid = set->widmap[wid][set->cur];
717         for (j = 0; j < n_hist; ++j) {
718             if (history[j] == NGRAM_INVALID_WID)
719                 set->maphist[j] = NGRAM_INVALID_WID;
720             else
721                 set->maphist[j] = set->widmap[history[j]][set->cur];
722         }
723         score = ngram_ng_score(set->lms[set->cur],
724                                mapwid, set->maphist, n_hist, n_used);
725     }
726
727     return score;
728 }
729
730 static int32
731 ngram_model_set_raw_score(ngram_model_t *base, int32 wid,
732                           int32 *history, int32 n_hist,
733                           int32 *n_used)
734 {
735     ngram_model_set_t *set = (ngram_model_set_t *)base;
736     int32 mapwid;
737     int32 score;
738     int32 i;
739
740     /* Truncate the history. */
741     if (n_hist > base->n - 1)
742         n_hist = base->n - 1;
743
744     /* Interpolate if there is no current. */
745     if (set->cur == -1) {
746         score = base->log_zero;
747         for (i = 0; i < set->n_models; ++i) {
748             int32 j;
749             /* Map word and history IDs for each model. */
750             mapwid = set->widmap[wid][i];
751             for (j = 0; j < n_hist; ++j) {
752                 if (history[j] == NGRAM_INVALID_WID)
753                     set->maphist[j] = NGRAM_INVALID_WID;
754                 else
755                     set->maphist[j] = set->widmap[history[j]][i];
756             }
757             score = logmath_add(base->lmath, score,
758                                 set->lweights[i] + 
759                                 ngram_ng_prob(set->lms[i],
760                                               mapwid, set->maphist, n_hist, n_used));
761         }
762     }
763     else {
764         int32 j;
765         /* Map word and history IDs (FIXME: do this in a function?) */
766         mapwid = set->widmap[wid][set->cur];
767         for (j = 0; j < n_hist; ++j) {
768             if (history[j] == NGRAM_INVALID_WID)
769                 set->maphist[j] = NGRAM_INVALID_WID;
770             else
771                 set->maphist[j] = set->widmap[history[j]][set->cur];
772         }
773         score = ngram_ng_prob(set->lms[set->cur],
774                               mapwid, set->maphist, n_hist, n_used);
775     }
776
777     return score;
778 }
779
780 static int32
781 ngram_model_set_add_ug(ngram_model_t *base,
782                        int32 wid, int32 lweight)
783 {
784     ngram_model_set_t *set = (ngram_model_set_t *)base;
785     int32 *newwid;
786     int32 i, prob;
787
788     /* At this point the word has already been added to the master
789        model and we have a new word ID for it.  Add it to active
790        submodels and track the word IDs. */
791     newwid = ckd_calloc(set->n_models, sizeof(*newwid));
792     prob = base->log_zero;
793     for (i = 0; i < set->n_models; ++i) {
794         int32 wprob, n_hist;
795
796         /* Only add to active models. */
797         if (set->cur == -1 || set->cur == i) {
798             /* Did this word already exist? */
799             newwid[i] = ngram_wid(set->lms[i], base->word_str[wid]);
800             if (newwid[i] == NGRAM_INVALID_WID) {
801                 /* Add it to the submodel. */
802                 newwid[i] = ngram_model_add_word(set->lms[i], base->word_str[wid],
803                                                  logmath_exp(base->lmath, lweight));
804                 if (newwid[i] == NGRAM_INVALID_WID) {
805                     ckd_free(newwid);
806                     return base->log_zero;
807                 }
808             }
809             /* Now get the unigram probability for the new word and either
810              * interpolate it or use it (if this is the current model). */
811             wprob = ngram_ng_prob(set->lms[i], newwid[i], NULL, 0, &n_hist);
812             if (set->cur == i)
813                 prob = wprob;
814             else if (set->cur == -1)
815                 prob = logmath_add(base->lmath, prob, set->lweights[i] + wprob);
816         }
817         else {
818             newwid[i] = NGRAM_INVALID_WID;
819         }
820     }
821     /* Okay we have the word IDs for this in all the submodels.  Now
822        do some complicated memory mangling to add this to the
823        widmap. */
824     set->widmap = ckd_realloc(set->widmap, base->n_words * sizeof(*set->widmap));
825     set->widmap[0] = ckd_realloc(set->widmap[0],
826                                  base->n_words
827                                  * set->n_models
828                                  * sizeof(**set->widmap));
829     for (i = 0; i < base->n_words; ++i)
830         set->widmap[i] = set->widmap[0] + i * set->n_models;
831     memcpy(set->widmap[wid], newwid, set->n_models * sizeof(*newwid));
832     ckd_free(newwid);
833     return prob;
834 }
835
836 static void
837 ngram_model_set_free(ngram_model_t *base)
838 {
839     ngram_model_set_t *set = (ngram_model_set_t *)base;
840     int32 i;
841
842     for (i = 0; i < set->n_models; ++i)
843         ngram_model_free(set->lms[i]);
844     ckd_free(set->lms);
845     for (i = 0; i < set->n_models; ++i)
846         ckd_free(set->names[i]);
847     ckd_free(set->names);
848     ckd_free(set->lweights);
849     ckd_free(set->maphist);
850     ckd_free_2d((void **)set->widmap);
851 }
852
853 static void
854 ngram_model_set_flush(ngram_model_t *base)
855 {
856     ngram_model_set_t *set = (ngram_model_set_t *)base;
857     int32 i;
858
859     for (i = 0; i < set->n_models; ++i)
860         ngram_model_flush(set->lms[i]);
861 }
862
863 static ngram_funcs_t ngram_model_set_funcs = {
864     ngram_model_set_free,          /* free */
865     ngram_model_set_apply_weights, /* apply_weights */
866     ngram_model_set_score,         /* score */
867     ngram_model_set_raw_score,     /* raw_score */
868     ngram_model_set_add_ug,        /* add_ug */
869     ngram_model_set_flush          /* flush */
870 };