add unigram to gen k mixture model
authorPeng Wu <alexepico@gmail.com>
Tue, 7 Jun 2011 08:10:35 +0000 (16:10 +0800)
committerPeng Wu <alexepico@gmail.com>
Tue, 7 Jun 2011 08:10:35 +0000 (16:10 +0800)
utils/training/gen_k_mixture_model.cpp
utils/training/k_mixture_model.h

index 8ab7b2d36e9dc73e5fcf120846ada4e4a2f78edb..c26ac498d9332eaac6a11dec679da17c32d0d7ba 100644 (file)
@@ -27,6 +27,7 @@
 
 typedef GHashTable * HashofDocument;
 typedef GHashTable * HashofSecondWord;
+typedef GHashTable * HashofUnigram;
 
 /* Hash token of Hash token of word count. */
 static guint32 g_maximum_occurs = 20;
@@ -44,7 +45,8 @@ void print_help(){
 
 
 bool read_document(PhraseLargeTable * phrases, FILE * document,
-                   HashofDocument hash_of_document){
+                   HashofDocument hash_of_document,
+                   HashofUnigram hash_of_unigram){
     char * linebuf = NULL;
     size_t size = 0;
     phrase_token_t last_token, cur_token = last_token = 0;
@@ -76,6 +78,20 @@ bool read_document(PhraseLargeTable * phrases, FILE * document,
         if ( null_token == cur_token )
             continue;
 
+        gpointer value = NULL;
+        gboolean lookup_result = g_hash_table_lookup_extended
+            (hash_of_unigram, GUINT_TO_POINTER(cur_token),
+             NULL, &value);
+        if ( !lookup_result ){
+            g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(cur_token),
+                                GUINT_TO_POINTER(1));
+        } else {
+            guint32 freq = GPOINTER_TO_UINT(value);
+            freq ++;
+            g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(cur_token),
+                                GUINT_TO_POINTER(freq));
+        }
+
         /* skip pi-gram training. */
         if ( null_token == last_token ){
             if ( !g_train_pi_gram )
@@ -84,9 +100,8 @@ bool read_document(PhraseLargeTable * phrases, FILE * document,
         }
 
         /* remember the (last_token, cur_token) word pair. */
-        gpointer value = NULL;
         HashofSecondWord hash_of_second_word = NULL;
-        gboolean lookup_result = g_hash_table_lookup_extended
+        lookup_result = g_hash_table_lookup_extended
             (hash_of_document, GUINT_TO_POINTER(last_token),
              NULL, &value);
         if ( !lookup_result ){
@@ -227,6 +242,33 @@ static bool train_second_word(KMixtureModelBigram * bigram,
     return true;
 }
 
+/* Note: this method is a post-processing method, run this last. */
+static bool post_processing_unigram(KMixtureModelBigram * bigram,
+                                    HashofUnigram hash_of_unigram){
+    GHashTableIter iter;
+    gpointer key, value;
+    guint32 total_freq = 0;
+
+    g_hash_table_iter_init(&iter, hash_of_unigram);
+    while (g_hash_table_iter_next(&iter, &key, &value)){
+        guint32 token = GPOINTER_TO_UINT(key);
+        guint32 freq = GPOINTER_TO_UINT(value);
+        KMixtureModelArrayHeader array_header;
+        memset(&array_header, 0, sizeof(KMixtureModelArrayHeader));
+        bool result = bigram->get_array_header(token, array_header);
+        array_header.m_freq += freq;
+        total_freq += freq;
+        bigram->set_array_header(token, array_header);
+    }
+
+    KMixtureModelMagicHeader magic_header;
+    assert(bigram->get_magic_header(magic_header));
+    magic_header.m_total_freq += total_freq;
+    assert(bigram->set_magic_header(magic_header));
+
+    return true;
+}
+
 int main(int argc, char * argv[]){
     int i = 1;
     const char * k_mixture_model_filename = NULL;
@@ -282,8 +324,11 @@ int main(int argc, char * argv[]){
 
         HashofDocument hash_of_document = g_hash_table_new
             (g_direct_hash, g_direct_equal);
+        HashofUnigram hash_of_unigram = g_hash_table_new
+            (g_direct_hash, g_direct_equal);
 
-        assert(read_document(&phrases, document, hash_of_document));
+        assert(read_document(&phrases, document,
+                             hash_of_document, hash_of_unigram));
         fclose(document);
         document = NULL;
 
@@ -302,6 +347,8 @@ int main(int argc, char * argv[]){
         magic_header.m_N ++;
         assert(bigram.set_magic_header(magic_header));
 
+        post_processing_unigram(&bigram, hash_of_unigram);
+
         /* free resources of g_hash_of_document */
         g_hash_table_iter_init(&iter, hash_of_document);
         while (g_hash_table_iter_next(&iter, &key, &value)) {
@@ -312,6 +359,9 @@ int main(int argc, char * argv[]){
         g_hash_table_unref(hash_of_document);
         hash_of_document = NULL;
 
+        g_hash_table_unref(hash_of_unigram);
+        hash_of_unigram = NULL;
+
         ++i;
     }
 
index f122792cf92a8930cbb268d88fdbc962563cd607..086455f79e98133738b88dc6ebc1863701938aee 100644 (file)
@@ -119,9 +119,10 @@ typedef struct{
 
 typedef struct{
     /* dummy varibles */
-    guint32 dummy[2];
+    guint32 dummy[3];
     /* the freq of uni-gram. see m_total_freq in magic header also. */
     guint32 m_freq;
+    guint32 dummy2[3];
     /* the total number of instances of word W1. */
     guint32 m_WC;
 } KMixtureModelArrayHeader;