9c7f2490f3f1958ed326e7fcce06ed077fd60d31
[platform/upstream/openfst.git] / src / include / fst / extensions / linear / linearscript.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3
4 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
5 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
6
7 #include <istream>
8 #include <sstream>
9 #include <string>
10 #include <vector>
11
12 #include <fst/compat.h>
13 #include <fst/extensions/linear/linear-fst-data-builder.h>
14 #include <fst/extensions/linear/linear-fst.h>
15 #include <fstream>
16 #include <fst/symbol-table.h>
17 #include <fst/script/arg-packs.h>
18 #include <fst/script/script-impl.h>
19
20 DECLARE_string(delimiter);
21 DECLARE_string(empty_symbol);
22 DECLARE_string(start_symbol);
23 DECLARE_string(end_symbol);
24 DECLARE_bool(classifier);
25
26 namespace fst {
27 namespace script {
28 typedef args::Package<const string &, const string &, const string &, char **,
29                       int, const string &, const string &, const string &,
30                       const string &> LinearCompileArgs;
31
32 bool ValidateDelimiter();
33 bool ValidateEmptySymbol();
34
35 // Returns the proper label given the symbol. For symbols other than
36 // `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol
37 // table to decide the label. Depending on whether
38 // `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it
39 // either returns `kNoLabel` for later processing or decides the label
40 // right away.
41 template <class Arc>
42 inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) {
43   if (str == FLAGS_start_symbol)
44     return str == FLAGS_end_symbol ? kNoLabel
45                                    : LinearFstData<Arc>::kStartOfSentence;
46   else if (str == FLAGS_end_symbol)
47     return LinearFstData<Arc>::kEndOfSentence;
48   else
49     return syms->AddSymbol(str);
50 }
51
52 // Splits `str` with `delim` as the delimiter and stores the labels in
53 // `output`.
54 template <class Arc>
55 void SplitAndPush(const string &str, const char delim, SymbolTable *syms,
56                   std::vector<typename Arc::Label> *output) {
57   if (str == FLAGS_empty_symbol) return;
58   std::istringstream strm(str);
59   string buf;
60   while (std::getline(strm, buf, delim))
61     output->push_back(LookUp<Arc>(buf, syms));
62 }
63
64 // Like `std::replace_copy` but returns the number of modifications
65 template <class InputIterator, class OutputIterator, class T>
66 size_t ReplaceCopy(InputIterator first, InputIterator last,
67                    OutputIterator result, const T &old_value,
68                    const T &new_value) {
69   size_t changes = 0;
70   while (first != last) {
71     if (*first == old_value) {
72       *result = new_value;
73       ++changes;
74     } else {
75       *result = *first;
76     }
77     ++first;
78     ++result;
79   }
80   return changes;
81 }
82
83 template <class Arc>
84 bool GetVocabRecord(const string &vocab, std::istream &strm,  // NOLINT
85                     SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
86                     typename Arc::Label *word,
87                     std::vector<typename Arc::Label> *feature_labels,
88                     std::vector<typename Arc::Label> *possible_labels,
89                     size_t *num_line);
90
91 template <class Arc>
92 bool GetModelRecord(const string &model, std::istream &strm,  // NOLINT
93                     SymbolTable *fsyms, SymbolTable *osyms,
94                     std::vector<typename Arc::Label> *input_labels,
95                     std::vector<typename Arc::Label> *output_labels,
96                     typename Arc::Weight *weight, size_t *num_line);
97
98 // Reads in vocabulary file. Each line is in the following format
99 //
100 //   word <whitespace> features [ <whitespace> possible output ]
101 //
102 // where features and possible output are `FLAGS_delimiter`-delimited lists of
103 // tokens
104 template <class Arc>
105 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
106               SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
107   std::ifstream in(vocab);
108   if (!in) LOG(FATAL) << "Can't open file: " << vocab;
109   size_t num_line = 0, num_added = 0;
110   std::vector<string> fields;
111   std::vector<typename Arc::Label> feature_labels, possible_labels;
112   typename Arc::Label word;
113   while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
114                              &feature_labels, &possible_labels, &num_line)) {
115     if (word == kNoLabel) {
116       LOG(WARNING) << "Ignored: boundary word: " << fields[0];
117       continue;
118     }
119     if (possible_labels.empty())
120       num_added += builder->AddWord(word, feature_labels);
121     else
122       num_added += builder->AddWord(word, feature_labels, possible_labels);
123   }
124   VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
125           << vocab;
126 }
127
128 template <class Arc>
129 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
130               SymbolTable *osyms,
131               LinearClassifierFstDataBuilder<Arc> *builder) {
132   std::ifstream in(vocab);
133   if (!in) LOG(FATAL) << "Can't open file: " << vocab;
134   size_t num_line = 0, num_added = 0;
135   std::vector<string> fields;
136   std::vector<typename Arc::Label> feature_labels, possible_labels;
137   typename Arc::Label word;
138   while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
139                              &feature_labels, &possible_labels, &num_line)) {
140     if (!possible_labels.empty())
141       LOG(FATAL)
142           << "Classifier vocabulary should not have possible output constraint";
143     if (word == kNoLabel) {
144       LOG(WARNING) << "Ignored: boundary word: " << fields[0];
145       continue;
146     }
147     num_added += builder->AddWord(word, feature_labels);
148   }
149   VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
150           << vocab;
151 }
152
153 // Reads in model file. The first line is an integer designating the
154 // size of future window in the input sequences. After this, each line
155 // is in the following format
156 //
157 //   input sequence <whitespace> output sequence <whitespace> weight
158 //
159 // input sequence is a `FLAGS_delimiter`-delimited sequence of feature
160 // labels (see `AddVocab()`) . output sequence is a
161 // `FLAGS_delimiter`-delimited sequence of output labels where the
162 // last label is the output of the feature position before the history
163 // boundary.
164 template <class Arc>
165 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
166               LinearFstDataBuilder<Arc> *builder) {
167   std::ifstream in(model);
168   if (!in) LOG(FATAL) << "Can't open file: " << model;
169   string line;
170   std::getline(in, line);
171   if (!in) LOG(FATAL) << "Empty file: " << model;
172   size_t future_size;
173   {
174     std::istringstream strm(line);
175     strm >> future_size;
176     if (!strm) LOG(FATAL) << "Can't read future size: " << model;
177   }
178   size_t num_line = 1, num_added = 0;
179   const int group = builder->AddGroup(future_size);
180   CHECK_GE(group, 0);
181   VLOG(1) << "Group " << group << ": from " << model << "; future size is "
182           << future_size << ".";
183   // Add the rest of lines as a single feature group
184   std::vector<string> fields;
185   std::vector<typename Arc::Label> input_labels, output_labels;
186   typename Arc::Weight weight;
187   while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
188                              &output_labels, &weight, &num_line)) {
189     if (output_labels.empty())
190       LOG(FATAL) << "Empty output sequence in source " << model << ", line "
191                  << num_line;
192
193     const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
194                                          LinearFstData<Arc>::kEndOfSentence};
195
196     std::vector<typename Arc::Label> copy_input(input_labels.size()),
197         copy_output(output_labels.size());
198     for (int i = 0; i < 2; ++i) {
199       for (int j = 0; j < 2; ++j) {
200         size_t num_input_changes =
201             ReplaceCopy(input_labels.begin(), input_labels.end(),
202                         copy_input.begin(), kNoLabel, marks[i]);
203         size_t num_output_changes =
204             ReplaceCopy(output_labels.begin(), output_labels.end(),
205                         copy_output.begin(), kNoLabel, marks[j]);
206         if ((num_input_changes > 0 || i == 0) &&
207             (num_output_changes > 0 || j == 0))
208           num_added +=
209               builder->AddWeight(group, copy_input, copy_output, weight);
210       }
211     }
212   }
213   VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
214           << num_line << " lines.";
215 }
216
217 template <class Arc>
218 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
219               LinearClassifierFstDataBuilder<Arc> *builder) {
220   std::ifstream in(model);
221   if (!in) LOG(FATAL) << "Can't open file: " << model;
222   string line;
223   std::getline(in, line);
224   if (!in) LOG(FATAL) << "Empty file: " << model;
225   size_t future_size;
226   {
227     std::istringstream strm(line);
228     strm >> future_size;
229     if (!strm) LOG(FATAL) << "Can't read future size: " << model;
230   }
231   if (future_size != 0)
232     LOG(FATAL) << "Classifier model must have future size = 0; got "
233                << future_size << " from " << model;
234   size_t num_line = 1, num_added = 0;
235   const int group = builder->AddGroup();
236   CHECK_GE(group, 0);
237   VLOG(1) << "Group " << group << ": from " << model << "; future size is "
238           << future_size << ".";
239   // Add the rest of lines as a single feature group
240   std::vector<string> fields;
241   std::vector<typename Arc::Label> input_labels, output_labels;
242   typename Arc::Weight weight;
243   while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
244                              &output_labels, &weight, &num_line)) {
245     if (output_labels.size() != 1)
246       LOG(FATAL) << "Output not a single label in source " << model << ", line "
247                  << num_line;
248
249     const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
250                                          LinearFstData<Arc>::kEndOfSentence};
251
252     typename Arc::Label pred = output_labels[0];
253
254     std::vector<typename Arc::Label> copy_input(input_labels.size());
255     for (int i = 0; i < 2; ++i) {
256       size_t num_input_changes =
257           ReplaceCopy(input_labels.begin(), input_labels.end(),
258                       copy_input.begin(), kNoLabel, marks[i]);
259       if (num_input_changes > 0 || i == 0)
260         num_added += builder->AddWeight(group, copy_input, pred, weight);
261     }
262   }
263   VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
264           << num_line << " lines.";
265 }
266
267 void SplitByWhitespace(const string &str, std::vector<string> *out);
268 int ScanNumClasses(char **models, int models_length);
269
270 template <class Arc>
271 void LinearCompileTpl(LinearCompileArgs *args) {
272   const string &epsilon_symbol = args->arg1;
273   const string &unknown_symbol = args->arg2;
274   const string &vocab = args->arg3;
275   char **models = args->arg4;
276   const int models_length = args->arg5;
277   const string &out = args->arg6;
278   const string &save_isymbols = args->arg7;
279   const string &save_fsymbols = args->arg8;
280   const string &save_osymbols = args->arg9;
281
282   SymbolTable isyms,  // input (e.g. word tokens)
283       osyms,          // output (e.g. tags)
284       fsyms;          // feature (e.g. word identity, suffix, etc.)
285   isyms.AddSymbol(epsilon_symbol);
286   osyms.AddSymbol(epsilon_symbol);
287   fsyms.AddSymbol(epsilon_symbol);
288   isyms.AddSymbol(unknown_symbol);
289
290   VLOG(1) << "start-of-sentence label is "
291           << LinearFstData<Arc>::kStartOfSentence;
292   VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
293
294   if (FLAGS_classifier) {
295     int num_classes = ScanNumClasses(models, models_length);
296     LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
297                                                 &osyms);
298
299     AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
300     for (int i = 0; i < models_length; ++i)
301       AddModel(models[i], &fsyms, &osyms, &builder);
302
303     LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
304     fst.Write(out);
305   } else {
306     LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
307
308     AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
309     for (int i = 0; i < models_length; ++i)
310       AddModel(models[i], &fsyms, &osyms, &builder);
311
312     LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
313     fst.Write(out);
314   }
315
316   if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
317   if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
318   if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
319 }
320
321 void LinearCompile(const string &arc_type, const string &epsilon_symbol,
322                    const string &unknown_symbol, const string &vocab,
323                    char **models, int models_len, const string &out,
324                    const string &save_isymbols, const string &save_fsymbols,
325                    const string &save_osymbols);
326
327 template <class Arc>
328 bool GetVocabRecord(const string &vocab, std::istream &strm,  // NOLINT
329                     SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
330                     typename Arc::Label *word,
331                     std::vector<typename Arc::Label> *feature_labels,
332                     std::vector<typename Arc::Label> *possible_labels,
333                     size_t *num_line) {
334   string line;
335   if (!std::getline(strm, line)) return false;
336   ++(*num_line);
337
338   std::vector<string> fields;
339   SplitByWhitespace(line, &fields);
340   if (fields.size() != 3)
341     LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
342                << num_line;
343
344   feature_labels->clear();
345   possible_labels->clear();
346
347   *word = LookUp<Arc>(fields[0], isyms);
348
349   const char delim = FLAGS_delimiter[0];
350   SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
351   SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
352
353   return true;
354 }
355
356 template <class Arc>
357 bool GetModelRecord(const string &model, std::istream &strm,  // NOLINT
358                     SymbolTable *fsyms, SymbolTable *osyms,
359                     std::vector<typename Arc::Label> *input_labels,
360                     std::vector<typename Arc::Label> *output_labels,
361                     typename Arc::Weight *weight, size_t *num_line) {
362   string line;
363   if (!std::getline(strm, line)) return false;
364   ++(*num_line);
365
366   std::vector<string> fields;
367   SplitByWhitespace(line, &fields);
368   if (fields.size() != 3)
369     LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
370                << num_line;
371
372   input_labels->clear();
373   output_labels->clear();
374
375   const char delim = FLAGS_delimiter[0];
376   SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
377   SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
378
379   *weight = StrToWeight<typename Arc::Weight>(fields[2], model, *num_line);
380
381   GuessStartOrEnd<Arc>(input_labels, kNoLabel);
382   GuessStartOrEnd<Arc>(output_labels, kNoLabel);
383
384   return true;
385 }
386 }  // namespace script
387 }  // namespace fst
388
389 #define REGISTER_FST_LINEAR_OPERATIONS(Arc) \
390   REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs);
391
392 #endif  // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_