Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / extensions / linear / linearscript.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3
4 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
5 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
6
7 #include <istream>
8 #include <sstream>
9 #include <string>
10 #include <vector>
11
12 #include <fst/compat.h>
13 #include <fst/extensions/linear/linear-fst-data-builder.h>
14 #include <fst/extensions/linear/linear-fst.h>
15 #include <fstream>
16 #include <fst/symbol-table.h>
17 #include <fst/script/arg-packs.h>
18 #include <fst/script/script-impl.h>
19
20 DECLARE_string(delimiter);
21 DECLARE_string(empty_symbol);
22 DECLARE_string(start_symbol);
23 DECLARE_string(end_symbol);
24 DECLARE_bool(classifier);
25
26 namespace fst {
27 namespace script {
28 typedef args::Package<const string &, const string &, const string &, char **,
29                       int, const string &, const string &, const string &,
30                       const string &> LinearCompileArgs;
31
32 bool ValidateDelimiter();
33 bool ValidateEmptySymbol();
34
35 // Returns the proper label given the symbol. For symbols other than
36 // `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol
37 // table to decide the label. Depending on whether
38 // `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it
39 // either returns `kNoLabel` for later processing or decides the label
40 // right away.
41 template <class Arc>
42 inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) {
43   if (str == FLAGS_start_symbol)
44     return str == FLAGS_end_symbol ? kNoLabel
45                                    : LinearFstData<Arc>::kStartOfSentence;
46   else if (str == FLAGS_end_symbol)
47     return LinearFstData<Arc>::kEndOfSentence;
48   else
49     return syms->AddSymbol(str);
50 }
51
52 // Splits `str` with `delim` as the delimiter and stores the labels in
53 // `output`.
54 template <class Arc>
55 void SplitAndPush(const string &str, const char delim, SymbolTable *syms,
56                   std::vector<typename Arc::Label> *output) {
57   if (str == FLAGS_empty_symbol) return;
58   std::istringstream strm(str);
59   string buf;
60   while (std::getline(strm, buf, delim))
61     output->push_back(LookUp<Arc>(buf, syms));
62 }
63
64 // Like `std::replace_copy` but returns the number of modifications
65 template <class InputIterator, class OutputIterator, class T>
66 size_t ReplaceCopy(InputIterator first, InputIterator last,
67                    OutputIterator result, const T &old_value,
68                    const T &new_value) {
69   size_t changes = 0;
70   while (first != last) {
71     if (*first == old_value) {
72       *result = new_value;
73       ++changes;
74     } else {
75       *result = *first;
76     }
77     ++first;
78     ++result;
79   }
80   return changes;
81 }
82
83 template <class Arc>
84 bool GetVocabRecord(const string &vocab, std::istream &strm,  // NOLINT
85                     SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
86                     typename Arc::Label *word,
87                     std::vector<typename Arc::Label> *feature_labels,
88                     std::vector<typename Arc::Label> *possible_labels,
89                     size_t *num_line);
90
91 template <class Arc>
92 bool GetModelRecord(const string &model, std::istream &strm,  // NOLINT
93                     SymbolTable *fsyms, SymbolTable *osyms,
94                     std::vector<typename Arc::Label> *input_labels,
95                     std::vector<typename Arc::Label> *output_labels,
96                     typename Arc::Weight *weight, size_t *num_line);
97
98 // Reads in vocabulary file. Each line is in the following format
99 //
100 //   word <whitespace> features [ <whitespace> possible output ]
101 //
102 // where features and possible output are `FLAGS_delimiter`-delimited lists of
103 // tokens
104 template <class Arc>
105 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
106               SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
107   std::ifstream in(vocab);
108   if (!in) LOG(FATAL) << "Can't open file: " << vocab;
109   size_t num_line = 0, num_added = 0;
110   std::vector<string> fields;
111   std::vector<typename Arc::Label> feature_labels, possible_labels;
112   typename Arc::Label word;
113   while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
114                              &feature_labels, &possible_labels, &num_line)) {
115     if (word == kNoLabel) {
116       LOG(WARNING) << "Ignored: boundary word: " << fields[0];
117       continue;
118     }
119     if (possible_labels.empty())
120       num_added += builder->AddWord(word, feature_labels);
121     else
122       num_added += builder->AddWord(word, feature_labels, possible_labels);
123   }
124   VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
125           << vocab;
126 }
127
128 template <class Arc>
129 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
130               SymbolTable *osyms,
131               LinearClassifierFstDataBuilder<Arc> *builder) {
132   std::ifstream in(vocab);
133   if (!in) LOG(FATAL) << "Can't open file: " << vocab;
134   size_t num_line = 0, num_added = 0;
135   std::vector<string> fields;
136   std::vector<typename Arc::Label> feature_labels, possible_labels;
137   typename Arc::Label word;
138   while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
139                              &feature_labels, &possible_labels, &num_line)) {
140     if (!possible_labels.empty())
141       LOG(FATAL)
142           << "Classifier vocabulary should not have possible output constraint";
143     if (word == kNoLabel) {
144       LOG(WARNING) << "Ignored: boundary word: " << fields[0];
145       continue;
146     }
147     num_added += builder->AddWord(word, feature_labels);
148   }
149   VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
150           << vocab;
151 }
152
153 // Reads in model file. The first line is an integer designating the
154 // size of future window in the input sequences. After this, each line
155 // is in the following format
156 //
157 //   input sequence <whitespace> output sequence <whitespace> weight
158 //
159 // input sequence is a `FLAGS_delimiter`-delimited sequence of feature
160 // labels (see `AddVocab()`) . output sequence is a
161 // `FLAGS_delimiter`-delimited sequence of output labels where the
162 // last label is the output of the feature position before the history
163 // boundary.
164 template <class Arc>
165 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
166               LinearFstDataBuilder<Arc> *builder) {
167   std::ifstream in(model);
168   if (!in) LOG(FATAL) << "Can't open file: " << model;
169   string line;
170   std::getline(in, line);
171   if (!in) LOG(FATAL) << "Empty file: " << model;
172   size_t future_size;
173   {
174     std::istringstream strm(line);
175     strm >> future_size;
176     if (!strm) LOG(FATAL) << "Can't read future size: " << model;
177   }
178   size_t num_line = 1, num_added = 0;
179   const int group = builder->AddGroup(future_size);
180   VLOG(1) << "Group " << group << ": from " << model << "; future size is "
181           << future_size << ".";
182   // Add the rest of lines as a single feature group
183   std::vector<string> fields;
184   std::vector<typename Arc::Label> input_labels, output_labels;
185   typename Arc::Weight weight;
186   while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
187                              &output_labels, &weight, &num_line)) {
188     if (output_labels.empty())
189       LOG(FATAL) << "Empty output sequence in source " << model << ", line "
190                  << num_line;
191
192     const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
193                                          LinearFstData<Arc>::kEndOfSentence};
194
195     std::vector<typename Arc::Label> copy_input(input_labels.size()),
196         copy_output(output_labels.size());
197     for (int i = 0; i < 2; ++i) {
198       for (int j = 0; j < 2; ++j) {
199         size_t num_input_changes =
200             ReplaceCopy(input_labels.begin(), input_labels.end(),
201                         copy_input.begin(), kNoLabel, marks[i]);
202         size_t num_output_changes =
203             ReplaceCopy(output_labels.begin(), output_labels.end(),
204                         copy_output.begin(), kNoLabel, marks[j]);
205         if ((num_input_changes > 0 || i == 0) &&
206             (num_output_changes > 0 || j == 0))
207           num_added +=
208               builder->AddWeight(group, copy_input, copy_output, weight);
209       }
210     }
211   }
212   VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
213           << num_line << " lines.";
214 }
215
216 template <class Arc>
217 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
218               LinearClassifierFstDataBuilder<Arc> *builder) {
219   std::ifstream in(model);
220   if (!in) LOG(FATAL) << "Can't open file: " << model;
221   string line;
222   std::getline(in, line);
223   if (!in) LOG(FATAL) << "Empty file: " << model;
224   size_t future_size;
225   {
226     std::istringstream strm(line);
227     strm >> future_size;
228     if (!strm) LOG(FATAL) << "Can't read future size: " << model;
229   }
230   if (future_size != 0)
231     LOG(FATAL) << "Classifier model must have future size = 0; got "
232                << future_size << " from " << model;
233   size_t num_line = 1, num_added = 0;
234   const int group = builder->AddGroup();
235   VLOG(1) << "Group " << group << ": from " << model << "; future size is "
236           << future_size << ".";
237   // Add the rest of lines as a single feature group
238   std::vector<string> fields;
239   std::vector<typename Arc::Label> input_labels, output_labels;
240   typename Arc::Weight weight;
241   while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
242                              &output_labels, &weight, &num_line)) {
243     if (output_labels.size() != 1)
244       LOG(FATAL) << "Output not a single label in source " << model << ", line "
245                  << num_line;
246
247     const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
248                                          LinearFstData<Arc>::kEndOfSentence};
249
250     typename Arc::Label pred = output_labels[0];
251
252     std::vector<typename Arc::Label> copy_input(input_labels.size());
253     for (int i = 0; i < 2; ++i) {
254       size_t num_input_changes =
255           ReplaceCopy(input_labels.begin(), input_labels.end(),
256                       copy_input.begin(), kNoLabel, marks[i]);
257       if (num_input_changes > 0 || i == 0)
258         num_added += builder->AddWeight(group, copy_input, pred, weight);
259     }
260   }
261   VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
262           << num_line << " lines.";
263 }
264
265 void SplitByWhitespace(const string &str, std::vector<string> *out);
266 int ScanNumClasses(char **models, int models_length);
267
268 template <class Arc>
269 void LinearCompileTpl(LinearCompileArgs *args) {
270   const string &epsilon_symbol = args->arg1;
271   const string &unknown_symbol = args->arg2;
272   const string &vocab = args->arg3;
273   char **models = args->arg4;
274   const int models_length = args->arg5;
275   const string &out = args->arg6;
276   const string &save_isymbols = args->arg7;
277   const string &save_fsymbols = args->arg8;
278   const string &save_osymbols = args->arg9;
279
280   SymbolTable isyms,  // input (e.g. word tokens)
281       osyms,          // output (e.g. tags)
282       fsyms;          // feature (e.g. word identity, suffix, etc.)
283   isyms.AddSymbol(epsilon_symbol);
284   osyms.AddSymbol(epsilon_symbol);
285   fsyms.AddSymbol(epsilon_symbol);
286   isyms.AddSymbol(unknown_symbol);
287
288   VLOG(1) << "start-of-sentence label is "
289           << LinearFstData<Arc>::kStartOfSentence;
290   VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
291
292   if (FLAGS_classifier) {
293     int num_classes = ScanNumClasses(models, models_length);
294     LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
295                                                 &osyms);
296
297     AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
298     for (int i = 0; i < models_length; ++i)
299       AddModel(models[i], &fsyms, &osyms, &builder);
300
301     LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
302     fst.Write(out);
303   } else {
304     LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
305
306     AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
307     for (int i = 0; i < models_length; ++i)
308       AddModel(models[i], &fsyms, &osyms, &builder);
309
310     LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
311     fst.Write(out);
312   }
313
314   if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
315   if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
316   if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
317 }
318
319 void LinearCompile(const string &arc_type, const string &epsilon_symbol,
320                    const string &unknown_symbol, const string &vocab,
321                    char **models, int models_len, const string &out,
322                    const string &save_isymbols, const string &save_fsymbols,
323                    const string &save_osymbols);
324
325 template <class Arc>
326 bool GetVocabRecord(const string &vocab, std::istream &strm,  // NOLINT
327                     SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
328                     typename Arc::Label *word,
329                     std::vector<typename Arc::Label> *feature_labels,
330                     std::vector<typename Arc::Label> *possible_labels,
331                     size_t *num_line) {
332   string line;
333   if (!std::getline(strm, line)) return false;
334   ++(*num_line);
335
336   std::vector<string> fields;
337   SplitByWhitespace(line, &fields);
338   if (fields.size() != 3)
339     LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
340                << num_line;
341
342   feature_labels->clear();
343   possible_labels->clear();
344
345   *word = LookUp<Arc>(fields[0], isyms);
346
347   const char delim = FLAGS_delimiter[0];
348   SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
349   SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
350
351   return true;
352 }
353
354 template <class Arc>
355 bool GetModelRecord(const string &model, std::istream &strm,  // NOLINT
356                     SymbolTable *fsyms, SymbolTable *osyms,
357                     std::vector<typename Arc::Label> *input_labels,
358                     std::vector<typename Arc::Label> *output_labels,
359                     typename Arc::Weight *weight, size_t *num_line) {
360   string line;
361   if (!std::getline(strm, line)) return false;
362   ++(*num_line);
363
364   std::vector<string> fields;
365   SplitByWhitespace(line, &fields);
366   if (fields.size() != 3)
367     LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
368                << num_line;
369
370   input_labels->clear();
371   output_labels->clear();
372
373   const char delim = FLAGS_delimiter[0];
374   SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
375   SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
376
377   *weight = StrToWeight<typename Arc::Weight>(fields[2], model, *num_line);
378
379   GuessStartOrEnd<Arc>(input_labels, kNoLabel);
380   GuessStartOrEnd<Arc>(output_labels, kNoLabel);
381
382   return true;
383 }
384 }  // namespace script
385 }  // namespace fst
386
387 #define REGISTER_FST_LINEAR_OPERATIONS(Arc) \
388   REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs);
389
390 #endif  // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_