Added initial version of Generic Text Classifier
[platform/core/api/generic-text-classifier.git] / src / gc_categorization.cpp
1 /*
2  * Copyright (c) 2011 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the License);
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an AS IS BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <mutex>
18
19 #include "gc_persistent_storage.h"
20 #include "gc_data_preprocessing.h"
21 #include "gc_categorization.h"
22
23 /*This set of MACROS has to be called in this sequence only, first call START Macro then END Macro */
24 #define NONINTERRUPT_BLOCK_START {int l1234Var = 0; mtx.lock();
25 #define NONINTERRUPT_BLOCK_END mtx.unlock(); g1234Var = l1234Var; if (cancelFlag) { resetCancelFlag(); return res; }}
26
27 /*
28 * This MACROS has to be called when we want to exit from function based on cancel flag, This must not be
29 * called in between NONINTERRUPT_BLOCK_START and NONINTERRUPT_BLOCK_END macro
30 */
31 #define UNLOCKED_EXIT_POINT int l1234Var = 0; if (cancelFlag) {g1234Var = l1234Var; resetCancelFlag(); return res; }
32
33 mutex mtx;
34
35 classificationResult::classificationResult()
36 {
37         m_classification_api_sucess_state = false;
38         m_classification_api_error = GC_ERROR_CODE_UNKNOWN;
39 }
40
41 vector<float> Classifier::computeMean(PersistentStorage& ps, const vector<string>& words)
42 {
43         vector<vector<float>> embeds = ps.getWordEmbedding(words);
44
45         int rows = embeds.size();
46         vector<float> result(W2VDIM);
47
48         if (embeds.size() == 0)
49                 return result;
50
51         for (int i = 0; i < W2VDIM; i++) {
52                 for (int j = 0; j < rows; j++)
53                         result[i] += embeds[j][i];
54                 result[i] = result[i] / rows;
55         }
56
57         return result;
58 }
59
60 float Classifier::computeCosine(const vector<float>& A, const vector<float>& B)
61 {
62         long double dot = 0.0, denom_a = 0.0, denom_b = 0.0;
63
64         for (unsigned int i = 0; i < A.size(); ++i) {
65                 dot += A[i] * B[i] ;
66                 denom_a += A[i] * A[i] ;
67                 denom_b += B[i] * B[i] ;
68         }
69
70         return dot / sqrt(denom_a * denom_b) ;
71 }
72
73 static void computeCategoryProfile(CategoryProfile& usrProf, vector<string>& contentWords, PersistentStorage& ps)
74 {
75         vector<float> w2vsum(W2VDIM);
76         vector<vector<float>> wordVectors = ps.getWordEmbedding(contentWords);
77         int numWords = wordVectors.size(), curWords = usrProf.wordCount;
78
79         for(int i = 0; i < W2VDIM; i++)
80                 for(int j = 0; j < numWords; j++)
81                         w2vsum[i] += wordVectors[j][i];
82
83         for(int j = 0; j < W2VDIM; j++)
84                 usrProf.profile[j] = (usrProf.profile[j] * curWords + w2vsum[j]) / (curWords + numWords);
85         usrProf.wordCount += numWords;
86 }
87
88 int Classifier::addCategory(const string userID, const string category)
89 {
90         PersistentStorage ps;
91         int retVal = ps.addCategoryProfile(userID, category);
92         return retVal;
93 }
94
95 int Classifier::deleteCategory(const string userID, const string category)
96 {
97         PersistentStorage ps;
98         int retVal = ps.deleteCategoryProfile(userID, category);
99         return retVal;
100 }
101
102 int Classifier::updateCategory(const string userID, const string category, const char* text)
103 {
104         PersistentStorage ps;
105
106         pair<vector<string>, vector<string>> textWords;
107         vector<string> keyWords, contentWords;
108
109         textWords = extractKeywords(text, ps);
110         keyWords = textWords.first;
111         contentWords = textWords.second;
112
113         CategoryProfile categoryProfile;
114
115         categoryProfile = ps.getCategoryProfile(category);
116         computeCategoryProfile(categoryProfile, contentWords, ps);
117         categoryProfile.categoryName = category;
118         int retVal = ps.updateCategoryProfile(userID, categoryProfile);
119         return retVal;
120 }
121
122 classificationResult Classifier::getCategoryInference(const string userID, const char* inputText)
123 {
124         PersistentStorage ps;
125         classificationResult res;
126         vector<pair<float, string>> result;
127         vector<string> words, keyWords, contentWords;
128
129         pair<vector<string>, vector<string>> textWords = extractKeywords(inputText, ps);
130         keyWords = textWords.first;
131         contentWords = textWords.second;
132
133         vector<float> mean = computeMean(ps, contentWords);
134         vector<pair<string, vector<float>>> tempCatProf;
135         tempCatProf = ps.getAllUserCategoryProfile(userID);
136         tempCatProf.insert(tempCatProf.end(), catProf.begin(), catProf.end());
137
138         for (unsigned int i = 0; i < tempCatProf.size(); i++) {
139                 float cos_val = 0.0;
140
141                 if (contentWords.size() > 0)
142                         cos_val = computeCosine(mean, tempCatProf[i].second);
143                 else
144                         cos_val = 0.0;
145                 result.push_back(make_pair(cos_val, tempCatProf[i].first));
146         }
147         sort(result.begin(), result.end());
148         reverse(result.begin(), result.end());
149         res.m_classification_api_sucess_state = true;
150
151         for (unsigned int i = 0; i < result.size(); i++)
152                 result[i].first -= result[tempCatProf.size() - 1].first;
153
154         float sftmxSum = result[0].first + result[1].first + result[2].first;
155         for (unsigned int i = 0; i < result.size(); i++) {
156                 res.category_score_vec.push_back(result[i].first/sftmxSum);
157                 res.category_name_vec.push_back(result[i].second);
158         }
159         return res;
160 }
161
162 Classifier::Classifier()
163 {
164         cancelFlag = false;
165         PersistentStorage ps;
166         catProf = ps.getAllCategoryProfile();
167 }
168
169 void Classifier::cancelThread()
170 {
171         cancelFlag = true;
172 }
173
174 void Classifier::resetCancelFlag()
175 {
176         cancelFlag = false;
177 }