1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
5 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
12 #include <fst/compat.h>
13 #include <fst/extensions/linear/linear-fst-data-builder.h>
14 #include <fst/extensions/linear/linear-fst.h>
16 #include <fst/symbol-table.h>
17 #include <fst/script/arg-packs.h>
18 #include <fst/script/script-impl.h>
20 DECLARE_string(delimiter);
21 DECLARE_string(empty_symbol);
22 DECLARE_string(start_symbol);
23 DECLARE_string(end_symbol);
24 DECLARE_bool(classifier);
28 typedef args::Package<const string &, const string &, const string &, char **,
29 int, const string &, const string &, const string &,
30 const string &> LinearCompileArgs;
// Flag validators, defined out of line. NOTE(review): presumably they check
// `FLAGS_delimiter` and `FLAGS_empty_symbol` respectively — confirm. Record
// parsing in this header uses only the first character of `FLAGS_delimiter`
// as the label separator.
bool ValidateDelimiter();
bool ValidateEmptySymbol();
35 // Returns the proper label given the symbol. For symbols other than
36 // `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol
37 // table to decide the label. Depending on whether
38 // `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it
39 // either returns `kNoLabel` for later processing or decides the label
42 inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) {
43 if (str == FLAGS_start_symbol)
44 return str == FLAGS_end_symbol ? kNoLabel
45 : LinearFstData<Arc>::kStartOfSentence;
46 else if (str == FLAGS_end_symbol)
47 return LinearFstData<Arc>::kEndOfSentence;
49 return syms->AddSymbol(str);
52 // Splits `str` with `delim` as the delimiter and stores the labels in
55 void SplitAndPush(const string &str, const char delim, SymbolTable *syms,
56 std::vector<typename Arc::Label> *output) {
57 if (str == FLAGS_empty_symbol) return;
58 std::istringstream strm(str);
60 while (std::getline(strm, buf, delim))
61 output->push_back(LookUp<Arc>(buf, syms));
// Like `std::replace_copy`, but returns the number of elements replaced.
65 template <class InputIterator, class OutputIterator, class T>
66 size_t ReplaceCopy(InputIterator first, InputIterator last,
67 OutputIterator result, const T &old_value,
70 while (first != last) {
71 if (*first == old_value) {
84 bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT
85 SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
86 typename Arc::Label *word,
87 std::vector<typename Arc::Label> *feature_labels,
88 std::vector<typename Arc::Label> *possible_labels,
// Reads the next record of model file `model` from `strm`. A record is one
// line that must split (via `SplitByWhitespace`) into exactly three fields;
// any other count is fatal:
//   1. the input (feature) sequence, mapped to labels via `fsyms` and stored
//      in `*input_labels`;
//   2. the output sequence, mapped via `osyms` into `*output_labels`;
//   3. the weight, parsed into `*weight`.
// `*num_line` is the line counter reported in diagnostics. Returns false once
// `strm` has no more lines.
bool GetModelRecord(const string &model, std::istream &strm, // NOLINT
                    SymbolTable *fsyms, SymbolTable *osyms,
                    std::vector<typename Arc::Label> *input_labels,
                    std::vector<typename Arc::Label> *output_labels,
                    typename Arc::Weight *weight, size_t *num_line);
// Reads in a vocabulary file. Each line has the following format:
//   word <whitespace> features <whitespace> possible output
102 // where features and possible output are `FLAGS_delimiter`-delimited lists of
105 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
106 SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
107 std::ifstream in(vocab);
108 if (!in) LOG(FATAL) << "Can't open file: " << vocab;
109 size_t num_line = 0, num_added = 0;
110 std::vector<string> fields;
111 std::vector<typename Arc::Label> feature_labels, possible_labels;
112 typename Arc::Label word;
113 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
114 &feature_labels, &possible_labels, &num_line)) {
115 if (word == kNoLabel) {
116 LOG(WARNING) << "Ignored: boundary word: " << fields[0];
119 if (possible_labels.empty())
120 num_added += builder->AddWord(word, feature_labels);
122 num_added += builder->AddWord(word, feature_labels, possible_labels);
124 VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
129 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
131 LinearClassifierFstDataBuilder<Arc> *builder) {
132 std::ifstream in(vocab);
133 if (!in) LOG(FATAL) << "Can't open file: " << vocab;
134 size_t num_line = 0, num_added = 0;
135 std::vector<string> fields;
136 std::vector<typename Arc::Label> feature_labels, possible_labels;
137 typename Arc::Label word;
138 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
139 &feature_labels, &possible_labels, &num_line)) {
140 if (!possible_labels.empty())
142 << "Classifier vocabulary should not have possible output constraint";
143 if (word == kNoLabel) {
144 LOG(WARNING) << "Ignored: boundary word: " << fields[0];
147 num_added += builder->AddWord(word, feature_labels);
149 VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
// Reads in a model file. The first line is an integer designating the
// size of the future window in the input sequences. After that, each line
// has the following format:
157 // input sequence <whitespace> output sequence <whitespace> weight
// The input sequence is a `FLAGS_delimiter`-delimited sequence of feature
// labels (see `AddVocab()`). The output sequence is a
161 // `FLAGS_delimiter`-delimited sequence of output labels where the
162 // last label is the output of the feature position before the history
165 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
166 LinearFstDataBuilder<Arc> *builder) {
167 std::ifstream in(model);
168 if (!in) LOG(FATAL) << "Can't open file: " << model;
170 std::getline(in, line);
171 if (!in) LOG(FATAL) << "Empty file: " << model;
174 std::istringstream strm(line);
176 if (!strm) LOG(FATAL) << "Can't read future size: " << model;
178 size_t num_line = 1, num_added = 0;
179 const int group = builder->AddGroup(future_size);
181 VLOG(1) << "Group " << group << ": from " << model << "; future size is "
182 << future_size << ".";
183 // Add the rest of lines as a single feature group
184 std::vector<string> fields;
185 std::vector<typename Arc::Label> input_labels, output_labels;
186 typename Arc::Weight weight;
187 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
188 &output_labels, &weight, &num_line)) {
189 if (output_labels.empty())
190 LOG(FATAL) << "Empty output sequence in source " << model << ", line "
193 const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
194 LinearFstData<Arc>::kEndOfSentence};
196 std::vector<typename Arc::Label> copy_input(input_labels.size()),
197 copy_output(output_labels.size());
198 for (int i = 0; i < 2; ++i) {
199 for (int j = 0; j < 2; ++j) {
200 size_t num_input_changes =
201 ReplaceCopy(input_labels.begin(), input_labels.end(),
202 copy_input.begin(), kNoLabel, marks[i]);
203 size_t num_output_changes =
204 ReplaceCopy(output_labels.begin(), output_labels.end(),
205 copy_output.begin(), kNoLabel, marks[j]);
206 if ((num_input_changes > 0 || i == 0) &&
207 (num_output_changes > 0 || j == 0))
209 builder->AddWeight(group, copy_input, copy_output, weight);
213 VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
214 << num_line << " lines.";
218 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
219 LinearClassifierFstDataBuilder<Arc> *builder) {
220 std::ifstream in(model);
221 if (!in) LOG(FATAL) << "Can't open file: " << model;
223 std::getline(in, line);
224 if (!in) LOG(FATAL) << "Empty file: " << model;
227 std::istringstream strm(line);
229 if (!strm) LOG(FATAL) << "Can't read future size: " << model;
231 if (future_size != 0)
232 LOG(FATAL) << "Classifier model must have future size = 0; got "
233 << future_size << " from " << model;
234 size_t num_line = 1, num_added = 0;
235 const int group = builder->AddGroup();
237 VLOG(1) << "Group " << group << ": from " << model << "; future size is "
238 << future_size << ".";
239 // Add the rest of lines as a single feature group
240 std::vector<string> fields;
241 std::vector<typename Arc::Label> input_labels, output_labels;
242 typename Arc::Weight weight;
243 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
244 &output_labels, &weight, &num_line)) {
245 if (output_labels.size() != 1)
246 LOG(FATAL) << "Output not a single label in source " << model << ", line "
249 const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
250 LinearFstData<Arc>::kEndOfSentence};
252 typename Arc::Label pred = output_labels[0];
254 std::vector<typename Arc::Label> copy_input(input_labels.size());
255 for (int i = 0; i < 2; ++i) {
256 size_t num_input_changes =
257 ReplaceCopy(input_labels.begin(), input_labels.end(),
258 copy_input.begin(), kNoLabel, marks[i]);
259 if (num_input_changes > 0 || i == 0)
260 num_added += builder->AddWeight(group, copy_input, pred, weight);
263 VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
264 << num_line << " lines.";
// Defined out of line. Splits `str` into fields (presumably on whitespace,
// per the name — confirm) and stores them in `out`. The record readers in
// this header require exactly three fields per line.
void SplitByWhitespace(const string &str, std::vector<string> *out);
// Defined out of line. Returns the class count passed to
// `LinearClassifierFstDataBuilder` when compiling a classifier from the
// `models_length` model files in `models`. NOTE(review): how the count is
// derived from the files is not visible here.
int ScanNumClasses(char **models, int models_length);
271 void LinearCompileTpl(LinearCompileArgs *args) {
272 const string &epsilon_symbol = args->arg1;
273 const string &unknown_symbol = args->arg2;
274 const string &vocab = args->arg3;
275 char **models = args->arg4;
276 const int models_length = args->arg5;
277 const string &out = args->arg6;
278 const string &save_isymbols = args->arg7;
279 const string &save_fsymbols = args->arg8;
280 const string &save_osymbols = args->arg9;
282 SymbolTable isyms, // input (e.g. word tokens)
283 osyms, // output (e.g. tags)
284 fsyms; // feature (e.g. word identity, suffix, etc.)
285 isyms.AddSymbol(epsilon_symbol);
286 osyms.AddSymbol(epsilon_symbol);
287 fsyms.AddSymbol(epsilon_symbol);
288 isyms.AddSymbol(unknown_symbol);
290 VLOG(1) << "start-of-sentence label is "
291 << LinearFstData<Arc>::kStartOfSentence;
292 VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
294 if (FLAGS_classifier) {
295 int num_classes = ScanNumClasses(models, models_length);
296 LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
299 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
300 for (int i = 0; i < models_length; ++i)
301 AddModel(models[i], &fsyms, &osyms, &builder);
303 LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
306 LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
308 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
309 for (int i = 0; i < models_length; ++i)
310 AddModel(models[i], &fsyms, &osyms, &builder);
312 LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
316 if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
317 if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
318 if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
// Compiles a linear FST for the arc type named `arc_type` from the vocabulary
// file `vocab` and the `models_len` model files in `models`. The parameters
// after `arc_type` have the same types and order as the fields of
// `LinearCompileArgs`. Defined out of line.
// NOTE(review): presumably this packs those parameters and dispatches to the
// `LinearCompileTpl` registered for `arc_type`. That template writes each
// symbol table whose `save_*` path is non-empty — confirm.
void LinearCompile(const string &arc_type, const string &epsilon_symbol,
                   const string &unknown_symbol, const string &vocab,
                   char **models, int models_len, const string &out,
                   const string &save_isymbols, const string &save_fsymbols,
                   const string &save_osymbols);
328 bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT
329 SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
330 typename Arc::Label *word,
331 std::vector<typename Arc::Label> *feature_labels,
332 std::vector<typename Arc::Label> *possible_labels,
335 if (!std::getline(strm, line)) return false;
338 std::vector<string> fields;
339 SplitByWhitespace(line, &fields);
340 if (fields.size() != 3)
341 LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
344 feature_labels->clear();
345 possible_labels->clear();
347 *word = LookUp<Arc>(fields[0], isyms);
349 const char delim = FLAGS_delimiter[0];
350 SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
351 SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
357 bool GetModelRecord(const string &model, std::istream &strm, // NOLINT
358 SymbolTable *fsyms, SymbolTable *osyms,
359 std::vector<typename Arc::Label> *input_labels,
360 std::vector<typename Arc::Label> *output_labels,
361 typename Arc::Weight *weight, size_t *num_line) {
363 if (!std::getline(strm, line)) return false;
366 std::vector<string> fields;
367 SplitByWhitespace(line, &fields);
368 if (fields.size() != 3)
369 LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
372 input_labels->clear();
373 output_labels->clear();
375 const char delim = FLAGS_delimiter[0];
376 SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
377 SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
379 *weight = StrToWeight<typename Arc::Weight>(fields[2], model, *num_line);
381 GuessStartOrEnd<Arc>(input_labels, kNoLabel);
382 GuessStartOrEnd<Arc>(output_labels, kNoLabel);
386 } // namespace script
// Registers `LinearCompileTpl` for arc type `Arc` through
// `REGISTER_FST_OPERATION` (from fst/script/script-impl.h), with
// `LinearCompileArgs` as its argument pack. The expansion already ends in `;`.
#define REGISTER_FST_LINEAR_OPERATIONS(Arc) \
  REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs);
392 #endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_