2 * Copyright (c) 2011 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the License);
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an AS IS BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
19 #include "gc_persistent_storage.h"
20 #include "gc_data_preprocessing.h"
21 #include "gc_categorization.h"
/*
 * Cancellation helpers for member functions that poll `cancelFlag`.
 *
 * NONINTERRUPT_BLOCK_START / NONINTERRUPT_BLOCK_END must always be used as a
 * pair, START first. START opens a brace scope, declares the dummy `l1234Var`
 * and locks `mtx`. END unlocks `mtx`, then, if `cancelFlag` is set, calls
 * resetCancelFlag() and returns `res` from the enclosing function; finally it
 * closes the scope opened by START.
 *
 * The dummy `l1234Var` (declared by START) is read by END and stored into
 * `g1234Var`, so an END with no matching START normally fails to compile.
 *
 * At the point of use, `mtx`, `cancelFlag`, `resetCancelFlag()`, `g1234Var`
 * and a local named `res` must all be in scope.
 */
#define NONINTERRUPT_BLOCK_START \
    {                            \
        int l1234Var = 0;        \
        mtx.lock();
#define NONINTERRUPT_BLOCK_END   \
        mtx.unlock();            \
        g1234Var = l1234Var;     \
        if (cancelFlag) {        \
            resetCancelFlag();   \
            return res;          \
        }                        \
    }

/*
 * UNLOCKED_EXIT_POINT: if cancellation has been requested (`cancelFlag` set),
 * call resetCancelFlag() and return `res` from the enclosing function.
 *
 * Use it only outside a NONINTERRUPT_BLOCK_START/END pair. It declares
 * `l1234Var` in the current scope, so placing it directly inside such a
 * block is a redeclaration error. For the same reason it can appear at most
 * once per scope.
 */
#define UNLOCKED_EXIT_POINT      \
    int l1234Var = 0;            \
    if (cancelFlag) {            \
        g1234Var = l1234Var;     \
        resetCancelFlag();       \
        return res;              \
    }
35 classificationResult::classificationResult()
37 m_classification_api_sucess_state = false;
38 m_classification_api_error = GC_ERROR_CODE_UNKNOWN;
41 vector<float> Classifier::computeMean(PersistentStorage& ps, const vector<string>& words)
43 vector<vector<float>> embeds = ps.getWordEmbedding(words);
45 int rows = embeds.size();
46 vector<float> result(W2VDIM);
48 if (embeds.size() == 0)
51 for (int i = 0; i < W2VDIM; i++) {
52 for (int j = 0; j < rows; j++)
53 result[i] += embeds[j][i];
54 result[i] = result[i] / rows;
60 float Classifier::computeCosine(const vector<float>& A, const vector<float>& B)
62 long double dot = 0.0, denom_a = 0.0, denom_b = 0.0;
64 for (unsigned int i = 0; i < A.size(); ++i) {
66 denom_a += A[i] * A[i] ;
67 denom_b += B[i] * B[i] ;
70 return dot / sqrt(denom_a * denom_b) ;
73 static void computeCategoryProfile(CategoryProfile& usrProf, vector<string>& contentWords, PersistentStorage& ps)
75 vector<float> w2vsum(W2VDIM);
76 vector<vector<float>> wordVectors = ps.getWordEmbedding(contentWords);
77 int numWords = wordVectors.size(), curWords = usrProf.wordCount;
79 for(int i = 0; i < W2VDIM; i++)
80 for(int j = 0; j < numWords; j++)
81 w2vsum[i] += wordVectors[j][i];
83 for(int j = 0; j < W2VDIM; j++)
84 usrProf.profile[j] = (usrProf.profile[j] * curWords + w2vsum[j]) / (curWords + numWords);
85 usrProf.wordCount += numWords;
// Adds `category` for `userID` by forwarding both to ps.addCategoryProfile();
// that call's result is stored in retVal.
// NOTE(review): presumably `ps` is a local PersistentStorage (as in
// updateCategory()) and retVal is what this function returns — confirm.
// userID and category are taken by value, so every call copies both strings.
88 int Classifier::addCategory(const string userID, const string category)
91 int retVal = ps.addCategoryProfile(userID, category);
// Removes `category` for `userID` by forwarding both to
// ps.deleteCategoryProfile(); that call's result is stored in retVal.
// NOTE(review): presumably `ps` is a local PersistentStorage (as in
// updateCategory()) and retVal is what this function returns — confirm.
// userID and category are taken by value, so every call copies both strings.
95 int Classifier::deleteCategory(const string userID, const string category)
98 int retVal = ps.deleteCategoryProfile(userID, category);
102 int Classifier::updateCategory(const string userID, const string category, const char* text)
104 PersistentStorage ps;
106 pair<vector<string>, vector<string>> textWords;
107 vector<string> keyWords, contentWords;
109 textWords = extractKeywords(text, ps);
110 keyWords = textWords.first;
111 contentWords = textWords.second;
113 CategoryProfile categoryProfile;
115 categoryProfile = ps.getCategoryProfile(category);
116 computeCategoryProfile(categoryProfile, contentWords, ps);
117 categoryProfile.categoryName = category;
118 int retVal = ps.updateCategoryProfile(userID, categoryProfile);
// Scores `inputText` against every category profile available to `userID`:
// the user's own profiles (ps.getAllUserCategoryProfile()) followed by the
// profiles cached in catProf by the constructor. Each category's score is the
// cosine similarity between the mean embedding of the input's content words
// and the category profile; `res` receives the category names and normalised
// scores, best first.
122 classificationResult Classifier::getCategoryInference(const string userID, const char* inputText)
124 PersistentStorage ps;
125 classificationResult res;
126 vector<pair<float, string>> result;
// NOTE(review): `words` and `keyWords` are never read below.
127 vector<string> words, keyWords, contentWords;
129 pair<vector<string>, vector<string>> textWords = extractKeywords(inputText, ps);
130 keyWords = textWords.first;
131 contentWords = textWords.second;
// Mean embedding of the input's content words.
133 vector<float> mean = computeMean(ps, contentWords);
// Candidate profiles: the user's own first, then the cached catProf entries.
134 vector<pair<string, vector<float>>> tempCatProf;
135 tempCatProf = ps.getAllUserCategoryProfile(userID);
136 tempCatProf.insert(tempCatProf.end(), catProf.begin(), catProf.end());
138 for (unsigned int i = 0; i < tempCatProf.size(); i++) {
// Only score against the profile when the input has content words.
// NOTE(review): cos_val's declaration and initial value are not shown here —
// presumably 0, so inputs without content words score 0; confirm.
141 if (contentWords.size() > 0)
142 cos_val = computeCosine(mean, tempCatProf[i].second);
145 result.push_back(make_pair(cos_val, tempCatProf[i].first));
// Sort ascending by (score, name), then reverse: highest score first, with
// ties ordered by name in descending order.
147 sort(result.begin(), result.end());
148 reverse(result.begin(), result.end());
149 res.m_classification_api_sucess_state = true;
// Shift scores so the lowest becomes 0. The last element is the lowest score
// after the descending sort; tempCatProf.size() - 1 indexes it because one
// entry is pushed per profile above. It is only overwritten on the final
// iteration, so every score is shifted by the original minimum.
151 for (unsigned int i = 0; i < result.size(); i++)
152 result[i].first -= result[tempCatProf.size() - 1].first;
// Normalise by the sum of the top three shifted scores. Despite the name,
// this is not a softmax: only the first three scores sum to 1.
// NOTE(review): result[0..2] is read unconditionally. That is out of bounds
// when fewer than three categories are available. If all scores are equal,
// every shifted score is 0 and the division below is 0/0 (NaN).
154 float sftmxSum = result[0].first + result[1].first + result[2].first;
155 for (unsigned int i = 0; i < result.size(); i++) {
156 res.category_score_vec.push_back(result[i].first/sftmxSum);
157 res.category_name_vec.push_back(result[i].second);
// Loads every profile returned by PersistentStorage::getAllCategoryProfile()
// into catProf. getCategoryInference() appends these to each user's own
// profiles when scoring.
// NOTE(review): the cache is filled once here; whether it is refreshed after
// addCategory()/updateCategory()/deleteCategory() is not shown — confirm.
162 Classifier::Classifier()
165 PersistentStorage ps;
166 catProf = ps.getAllCategoryProfile();
// Presumably requests cancellation of a running classification by setting
// cancelFlag. The UNLOCKED_EXIT_POINT and NONINTERRUPT_BLOCK_END macros poll
// that flag, so the classification would then return early — TODO confirm.
169 void Classifier::cancelThread()
174 void Classifier::resetCancelFlag()