write gen k mixture model in progress
author: Peng Wu <alexepico@gmail.com>
Tue, 10 May 2011 05:31:36 +0000 (13:31 +0800)
committer: Peng Wu <alexepico@gmail.com>
Tue, 10 May 2011 05:31:36 +0000 (13:31 +0800)
utils/training/gen_k_mixture_model.cpp

index 7d4e3ed..ba0936f 100644 (file)
 
 #include <glib.h>
 #include "pinyin.h"
+#include "k_mixture_model.h"
 
 typedef GHashTable * HashofWordPair;
 typedef GHashTable * HashofSecondWord;
 
 /* Hash of first-word token -> hash of second-word token -> word count. */
-HashofWordPair g_hash_of_document = NULL;
-PhraseLargeTable * g_phrases = NULL;
+static HashofWordPair g_hash_of_document = NULL; /* looked up by first-word token in train_single_gram(); each value is a HashofSecondWord. */
+static PhraseLargeTable * g_phrases = NULL;
+static KMixtureModelBigram * g_bigram = NULL; /* model that train_single_gram_wrapper() loads single grams from and stores them to. */
+static guint32 g_maximum_occurs = 20; /* baseline cap on a word pair's count; train_word_pair() drops larger counts. */
+static parameter_t g_maximum_increase_rates = 3.; /* for a known pair the cap becomes max(g_maximum_occurs, ceil(m_Mr * this rate)). */
 
 void print_help(){
     printf("gen_k_mixture_model [--skip-pi-gram-training]\n");
@@ -99,6 +103,110 @@ bool convert_document_to_hash(FILE * document){
     return true;
 }
 
+static void train_word_pair(gpointer key, gpointer value,
+                            gpointer user_data){
+    phrase_token_t token = GPOINTER_TO_UINT(key);
+    guint32 count = GPOINTER_TO_UINT(value);
+    KMixtureModelSingleGram * single_gram =
+        (KMixtureModelSingleGram *)user_data;
+    KMixtureModelArrayItem array_item;
+    guint32 delta = 0;
+
+    bool exists = single_gram->get_array_item(token, array_item);
+    if ( exists ) {
+        guint32 maximum_occurs_allowed = std_lite::max
+            (g_maximum_occurs,
+             (guint32)ceil(array_item.m_Mr * g_maximum_increase_rates));
+        /* Exceeds the maximum occurs allowed of the word or phrase,
+         * in a single document.
+         */
+        if ( count > maximum_occurs_allowed )
+            return;
+        array_item.m_WC += count;
+        /* array_item.m_T += count; the same as m_WC. */
+        array_item.m_N_n_0 ++;
+        if ( 1 == count )
+            array_item.m_n_1 ++;
+        array_item.m_Mr = std_lite::max(array_item.m_Mr, count);
+        delta = count;
+    } else { /* item doesn't exist. */
+        /* the same as above. */
+        if ( count > g_maximum_occurs )
+            return;
+        memset(&array_item, 0, sizeof(KMixtureModelArrayItem));
+        array_item.m_WC = count;
+        /* array_item.m_T = count; the same as m_WC. */
+        array_item.m_N_n_0 = 1;
+        if ( 1 == count )
+            array_item.m_n_1 = 1;
+        array_item.m_Mr = count;
+        delta = count;
+    }
+    /* save delta in the array header. */
+    KMixtureModelArrayHeader array_header;
+    single_gram->get_array_header(array_header);
+    array_header.m_WC += delta;
+    single_gram->set_array_header(array_header);
+}
+
+bool train_single_gram(phrase_token_t token,
+                       KMixtureModelSingleGram * single_gram,
+                       guint32 & delta){
+    assert(NULL != single_gram);
+    delta = 0; /* delta in WC of single_gram. */
+    KMixtureModelArrayHeader array_header;
+    assert(single_gram->get_array_header(array_header));
+    guint32 saved_array_header_WC = array_header.m_WC;
+
+    HashofSecondWord hash_of_second_word = NULL;
+    gpointer value = NULL;
+    assert(g_hash_table_lookup_extended
+           (g_hash_of_document, GUINT_TO_POINTER(token),
+            NULL, &value));
+    hash_of_second_word = (HashofSecondWord) value;
+    assert(NULL != hash_of_second_word);
+
+    g_hash_table_foreach(hash_of_second_word, train_word_pair, single_gram);
+
+    assert(single_gram->get_array_header(array_header));
+    delta = array_header.m_WC - saved_array_header_WC;
+    return true;
+}
+
+static void train_single_gram_wrapper(gpointer key, gpointer value,
+                                      gpointer user_data){
+    phrase_token_t token = GPOINTER_TO_UINT(key);
+    guint32 delta = 0;
+
+    KMixtureModelSingleGram * single_gram = NULL;
+    bool exists = g_bigram->load(token, single_gram);
+    if ( exists ){
+        train_single_gram(token, single_gram, delta);
+    } else { /* item doesn't exist. */
+        single_gram = new KMixtureModelSingleGram;
+        train_single_gram(token, single_gram, delta);
+    }
+
+    KMixtureModelMagicHeader magic_header;
+    assert(g_bigram->get_magic_header(magic_header));
+    if ( magic_header.m_WC + delta < magic_header.m_WC ){
+        fprintf(stderr, "the m_WC integer in magic header overflows.\n");
+        return;
+    }
+    magic_header.m_WC += delta;
+    magic_header.m_N ++;
+    assert(g_bigram->set_magic_header(magic_header));
+
+    /* save the single gram. */
+    assert(g_bigram->store(token, single_gram));
+    delete single_gram;
+}
+
+/* Train the model with the current document: runs
+ * train_single_gram_wrapper() once for every entry of g_hash_of_document.
+ * Always returns true.
+ */
+bool train_document(){
+    GHFunc train_entry = train_single_gram_wrapper;
+    g_hash_table_foreach(g_hash_of_document, train_entry, NULL);
+    return true;
+}
+
 int main(int argc, char * argv[]){
     g_hash_of_document = g_hash_table_new_full
         (g_int_hash, g_int_equal, NULL, (GDestroyNotify)g_hash_table_unref);