fixes unigram in k mixture model
authorPeng Wu <alexepico@gmail.com>
Thu, 23 Jun 2011 05:29:06 +0000 (13:29 +0800)
committerPeng Wu <alexepico@gmail.com>
Thu, 23 Jun 2011 05:29:06 +0000 (13:29 +0800)
utils/training/gen_k_mixture_model.cpp

index f89dce4..3e83945 100644 (file)
@@ -131,7 +131,8 @@ bool read_document(PhraseLargeTable * phrases, FILE * document,
     return true;
 }
 
-static void train_word_pair(KMixtureModelSingleGram * single_gram,
+static void train_word_pair(HashofUnigram hash_of_unigram,
+                            KMixtureModelSingleGram * single_gram,
                             phrase_token_t token2, guint32 count){
     KMixtureModelArrayItem array_item;
 
@@ -143,8 +144,24 @@ static void train_word_pair(KMixtureModelSingleGram * single_gram,
         /* Exceeds the maximum occurs allowed of the word or phrase,
          * in a single document.
          */
-        if ( count > maximum_occurs_allowed )
+        if ( count > maximum_occurs_allowed ){
+            gpointer value = NULL;
+            assert( g_hash_table_lookup_extended
+                    (hash_of_unigram, GUINT_TO_POINTER(token2),
+                     NULL, &value) );
+            guint32 freq = GPOINTER_TO_UINT(value);
+            freq -= count;
+            if ( freq > 0 ) {
+                g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(token2),
+                                    GUINT_TO_POINTER(freq));
+            } else if ( freq == 0 ) {
+                assert(g_hash_table_steal(hash_of_unigram,
+                                          GUINT_TO_POINTER(token2)));
+            } else {
+                assert(false);
+            }
             return;
+        }
         array_item.m_WC += count;
         /* array_item.m_T += count; the same as m_WC. */
         array_item.m_N_n_0 ++;
@@ -154,8 +171,24 @@ static void train_word_pair(KMixtureModelSingleGram * single_gram,
         assert(single_gram->set_array_item(token2, array_item));
     } else { /* item doesn't exist. */
         /* the same as above. */
-        if ( count > g_maximum_occurs )
+        if ( count > g_maximum_occurs ){
+            gpointer value = NULL;
+            assert( g_hash_table_lookup_extended
+                    (hash_of_unigram, GUINT_TO_POINTER(token2),
+                     NULL, &value) );
+            guint32 freq = GPOINTER_TO_UINT(value);
+            freq -= count;
+            if ( freq > 0 ) {
+                g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(token2),
+                                    GUINT_TO_POINTER(freq));
+            } else if ( freq == 0 ) {
+                assert(g_hash_table_steal(hash_of_unigram,
+                                          GUINT_TO_POINTER(token2)));
+            } else {
+                assert(false);
+            }
             return;
+        }
         memset(&array_item, 0, sizeof(KMixtureModelArrayItem));
         array_item.m_WC = count;
         /* array_item.m_T = count; the same as m_WC. */
@@ -173,7 +206,8 @@ static void train_word_pair(KMixtureModelSingleGram * single_gram,
     single_gram->set_array_header(array_header);
 }
 
-bool train_single_gram(HashofDocument hash_of_document,
+bool train_single_gram(HashofUnigram hash_of_unigram,
+                       HashofDocument hash_of_document,
                        KMixtureModelSingleGram * single_gram,
                        phrase_token_t token1,
                        guint32 & delta){
@@ -197,7 +231,7 @@ bool train_single_gram(HashofDocument hash_of_document,
     while (g_hash_table_iter_next(&iter, &key, &value)) {
         phrase_token_t token2 = GPOINTER_TO_UINT(key);
         guint32 count = GPOINTER_TO_UINT(value);
-        train_word_pair(single_gram, token2, count);
+        train_word_pair(hash_of_unigram, single_gram, token2, count);
     }
 
     assert(single_gram->get_array_header(array_header));
@@ -205,7 +239,8 @@ bool train_single_gram(HashofDocument hash_of_document,
     return true;
 }
 
-static bool train_second_word(KMixtureModelBigram * bigram,
+static bool train_second_word(HashofUnigram hash_of_unigram,
+                              KMixtureModelBigram * bigram,
                               HashofDocument hash_of_document,
                               phrase_token_t token1){
     guint32 delta = 0;
@@ -214,7 +249,8 @@ static bool train_second_word(KMixtureModelBigram * bigram,
     bool exists = bigram->load(token1, single_gram);
     if ( !exists )
         single_gram = new KMixtureModelSingleGram;
-    train_single_gram(hash_of_document, single_gram, token1, delta);
+    train_single_gram(hash_of_unigram, hash_of_document,
+                      single_gram, token1, delta);
 
     if ( 0 == delta ){ /* Please consider maximum occurs allowed. */
         delete single_gram;
@@ -337,7 +373,8 @@ int main(int argc, char * argv[]){
         g_hash_table_iter_init(&iter, hash_of_document);
         while (g_hash_table_iter_next(&iter, &key, &value)) {
             phrase_token_t token1 = GPOINTER_TO_UINT(key);
-            train_second_word(&bigram, hash_of_document, token1);
+            train_second_word(hash_of_unigram, &bigram,
+                              hash_of_document, token1);
         }
 
         KMixtureModelMagicHeader magic_header;