fixes estimate k mixture model
author Peng Wu <alexepico@gmail.com>
Wed, 8 Jun 2011 07:57:36 +0000 (15:57 +0800)
committer Peng Wu <alexepico@gmail.com>
Wed, 8 Jun 2011 07:57:36 +0000 (15:57 +0800)
utils/training/estimate_k_mixture_model.cpp

index 66c57f68ae54c9b9fcc7d9975d84723f0ca87e52..ee9be61687235e238e29d3715f8944f1219e0aea 100644 (file)
@@ -35,6 +35,9 @@ parameter_t compute_interpolation(KMixtureModelSingleGram * deleted_bigram,
     parameter_t lambda = 0, next_lambda = 0.6;
     parameter_t epsilon = 0.001;
 
+    KMixtureModelMagicHeader magic_header;
+    assert(unigram->get_magic_header(magic_header));
+
     while (fabs(lambda - next_lambda) > epsilon){
         lambda = next_lambda;
         next_lambda = 0;
@@ -52,23 +55,21 @@ parameter_t compute_interpolation(KMixtureModelSingleGram * deleted_bigram,
 
             {
                 parameter_t elem_poss = 0;
-                KMixtureModelArrayItem item;
-                if ( bigram && bigram->get_array_item(token, item) ){
-                    KMixtureModelArrayHeader header;
-                    assert(bigram->get_array_header(header));
-                    assert(0 != header.m_WC);
-                    elem_poss = item.m_WC / (parameter_t) header.m_WC;
+                KMixtureModelArrayHeader array_header;
+                KMixtureModelArrayItem array_item;
+                if ( bigram && bigram->get_array_item(token, array_item) ){
+                    assert(bigram->get_array_header(array_header));
+                    assert(0 != array_header.m_WC);
+                    elem_poss = array_item.m_WC / (parameter_t) array_header.m_WC;
                 }
                 numerator = lambda * elem_poss;
             }
 
             {
                 parameter_t elem_poss = 0;
-                KMixtureModelMagicHeader magic_header;
                 KMixtureModelArrayHeader array_header;
                 if (unigram->get_array_header(token, array_header)){
                     /* Note: optimize here? */
-                    assert(unigram->get_magic_header(magic_header));
                     assert(0 != magic_header.m_WC);
                     elem_poss = array_header.m_WC / (parameter_t) magic_header.m_WC;
                 }
@@ -120,6 +121,9 @@ int main(int argc, char * argv[]){
     }
 
     /* TODO: magic header signature check here. */
+    KMixtureModelBigram unigram(K_MIXTURE_MODEL_MAGIC_NUMBER);
+    unigram.attach(bigram_filename, ATTACH_READONLY);
+
     KMixtureModelBigram bigram(K_MIXTURE_MODEL_MAGIC_NUMBER);
     bigram.attach(bigram_filename, ATTACH_READONLY);
 
@@ -140,12 +144,20 @@ int main(int argc, char * argv[]){
         KMixtureModelSingleGram * deleted_single_gram = NULL;
         deleted_bigram.load(*token, deleted_single_gram);
 
-        parameter_t lambda = compute_interpolation(deleted_single_gram, &bigram, single_gram);
+        KMixtureModelArrayHeader array_header;
+        if (single_gram)
+            assert(single_gram->get_array_header(array_header));
+        KMixtureModelArrayHeader deleted_array_header;
+        assert(deleted_single_gram->get_array_header(deleted_array_header));
+
+        if ( 0 != array_header.m_WC && 0 != deleted_array_header.m_WC ) {
+            parameter_t lambda = compute_interpolation(deleted_single_gram, &unigram, single_gram);
 
-        printf("lambda:%f\n", lambda);
+            printf("lambda:%f\n", lambda);
 
-        lambda_sum += lambda;
-        lambda_count ++;
+            lambda_sum += lambda;
+            lambda_count ++;
+        }
 
         if (single_gram)
             delete single_gram;