7f96a6d7fc993c3cfcd1dcad2dd5373fde81f6a4
[platform/upstream/openfst.git] / src / include / fst / string.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Utilities to convert strings into FSTs.
5
6 #ifndef FST_LIB_STRING_H_
7 #define FST_LIB_STRING_H_
8
9 #include <sstream>
10 #include <string>
11 #include <vector>
12
13 #include <fst/log.h>
14
15 #include <fst/compact-fst.h>
16 #include <fst/icu.h>
17 #include <fst/mutable-fst.h>
18 #include <fst/util.h>
19
20
21 DECLARE_string(fst_field_separator);
22
23 namespace fst {
24
25 // This will eventually replace StringCompiler<Arc>::TokenType and
26 // StringPrinter<Arc>::TokenType.
27 enum StringTokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
28
29 namespace internal {
30
31 template <class Label>
32 bool ConvertSymbolToLabel(const char *str, const SymbolTable *syms,
33                           Label unknown_label, bool allow_negative,
34                           Label *output) {
35   int64 n;
36   if (syms) {
37     n = syms->Find(str);
38     if ((n == -1) && (unknown_label != kNoLabel)) n = unknown_label;
39     if (n == -1 || (!allow_negative && n < 0)) {
40       VLOG(1) << "ConvertSymbolToLabel: Symbol \"" << str
41               << "\" is not mapped to any integer label, symbol table = "
42               << syms->Name();
43       return false;
44     }
45   } else {
46     char *p;
47     n = strtoll(str, &p, 10);
48     if (p < str + strlen(str) || (!allow_negative && n < 0)) {
49       VLOG(1) << "ConvertSymbolToLabel: Bad label integer "
50               << "= \"" << str << "\"";
51       return false;
52     }
53   }
54   *output = n;
55   return true;
56 }
57
58 template <class Label>
59 bool ConvertStringToLabels(const string &str, StringTokenType token_type,
60                            const SymbolTable *syms, Label unknown_label,
61                            bool allow_negative, std::vector<Label> *labels) {
62   labels->clear();
63   if (token_type == StringTokenType::BYTE) {
64     for (const char c : str) labels->push_back(c);
65   } else if (token_type == StringTokenType::UTF8) {
66     return UTF8StringToLabels(str, labels);
67   } else {
68     auto *c_str = new char[str.size() + 1];
69     str.copy(c_str, str.size());
70     c_str[str.size()] = 0;
71     std::vector<char *> vec;
72     const string separator = "\n" + FLAGS_fst_field_separator;
73     SplitToVector(c_str, separator.c_str(), &vec, true);
74     for (const char *c : vec) {
75       Label label;
76       if (!ConvertSymbolToLabel(c, syms, unknown_label, allow_negative,
77                                 &label)) {
78         delete[] c_str;
79         return false;
80       }
81       labels->push_back(label);
82     }
83     delete[] c_str;
84   }
85   return true;
86 }
87
88 }  // namespace internal
89
90 // Functor for compiling a string in an FST.
91 template <class Arc>
92 class StringCompiler {
93  public:
94   using Label = typename Arc::Label;
95   using StateId = typename Arc::StateId;
96   using Weight = typename Arc::Weight;
97
98   explicit StringCompiler(StringTokenType token_type,
99                           const SymbolTable *syms = nullptr,
100                           Label unknown_label = kNoLabel,
101                           bool allow_negative = false)
102       : token_type_(token_type),
103         syms_(syms),
104         unknown_label_(unknown_label),
105         allow_negative_(allow_negative) {}
106
107   enum OPENFST_DEPRECATED("Use fst::StringTokenType") TokenType {
108     SYMBOL = 1,
109     BYTE = 2,
110     UTF8 = 3
111   };
112
113   OPENFST_DEPRECATED("Use fst::StringTokenType")
114   explicit StringCompiler(TokenType token_type,
115                           const SymbolTable *syms = nullptr,
116                           Label unknown_label = kNoLabel,
117                           bool allow_negative = false)
118       : StringCompiler(static_cast<StringTokenType>(token_type), syms,
119                        unknown_label, allow_negative) {}
120
121   // Compiles string into an FST.
122
123   template <class FST>
124   bool operator()(const string &str, FST *fst) const {
125     std::vector<Label> labels;
126     if (!internal::ConvertStringToLabels(str, token_type_, syms_,
127                                          unknown_label_, allow_negative_,
128                                          &labels)) {
129       return false;
130     }
131     Compile(labels, fst);
132     return true;
133   }
134
135   template <class FST>
136   bool operator()(const string &str, FST *fst, Weight weight) const {
137     std::vector<Label> labels;
138     if (!internal::ConvertStringToLabels(str, token_type_, syms_,
139                                          unknown_label_, allow_negative_,
140                                          &labels)) {
141       return false;
142     }
143     Compile(labels, fst, std::move(weight));
144     return true;
145   }
146
147  private:
148   void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
149                Weight weight = Weight::One()) const {
150     fst->DeleteStates();
151     while (fst->NumStates() <= labels.size()) fst->AddState();
152     for (StateId i = 0; i < labels.size(); ++i) {
153       fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
154     }
155     fst->SetStart(0);
156     fst->SetFinal(labels.size(), std::move(weight));
157   }
158
159   template <class Unsigned>
160   void Compile(const std::vector<Label> &labels,
161                CompactStringFst<Arc, Unsigned> *fst) const {
162     fst->SetCompactElements(labels.begin(), labels.end());
163   }
164
165   template <class Unsigned>
166   void Compile(const std::vector<Label> &labels,
167                CompactWeightedStringFst<Arc, Unsigned> *fst,
168                const Weight &weight = Weight::One()) const {
169     std::vector<std::pair<Label, Weight>> compacts;
170     compacts.reserve(labels.size() + 1);
171     for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
172       compacts.emplace_back(labels[i], Weight::One());
173     }
174     compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
175     fst->SetCompactElements(compacts.begin(), compacts.end());
176   }
177
178   const StringTokenType token_type_;
179   const SymbolTable *syms_;    // Symbol table (used when token type is symbol).
180   const Label unknown_label_;  // Label for token missing from symbol table.
181   const bool allow_negative_;  // Negative labels allowed?
182
183   StringCompiler(const StringCompiler &) = delete;
184   StringCompiler &operator=(const StringCompiler &) = delete;
185 };
186
187 // Functor for printing a string FST as a string.
188 template <class Arc>
189 class StringPrinter {
190  public:
191   using Label = typename Arc::Label;
192   using StateId = typename Arc::StateId;
193   using Weight = typename Arc::Weight;
194
195   explicit StringPrinter(StringTokenType token_type,
196                          const SymbolTable *syms = nullptr)
197       : token_type_(token_type), syms_(syms) {}
198
199   enum OPENFST_DEPRECATED("Use fst::StringTokenType") TokenType {
200     SYMBOL = 1,
201     BYTE = 2,
202     UTF8 = 3
203   };
204
205   OPENFST_DEPRECATED("Use fst::StringTokenType")
206   explicit StringPrinter(TokenType token_type,
207                          const SymbolTable *syms = nullptr)
208       : StringPrinter(static_cast<StringTokenType>(token_type), syms) {}
209
210   // Converts the FST into a string.
211   bool operator()(const Fst<Arc> &fst, string *result) {
212     if (!FstToLabels(fst)) {
213       VLOG(1) << "StringPrinter::operator(): FST is not a string";
214       return false;
215     }
216     result->clear();
217     if (token_type_ == StringTokenType::SYMBOL) {
218       std::stringstream sstrm;
219       for (size_t i = 0; i < labels_.size(); ++i) {
220         if (i) sstrm << *(FLAGS_fst_field_separator.rbegin());
221         if (!PrintLabel(labels_[i], sstrm)) return false;
222       }
223       *result = sstrm.str();
224     } else if (token_type_ == StringTokenType::BYTE) {
225       result->reserve(labels_.size());
226       for (size_t i = 0; i < labels_.size(); ++i) result->push_back(labels_[i]);
227     } else if (token_type_ == StringTokenType::UTF8) {
228       return LabelsToUTF8String(labels_, result);
229     } else {
230       VLOG(1) << "StringPrinter::operator(): Unknown token type: "
231               << token_type_;
232       return false;
233     }
234     return true;
235   }
236
237  private:
238   bool FstToLabels(const Fst<Arc> &fst) {
239     labels_.clear();
240     auto s = fst.Start();
241     if (s == kNoStateId) {
242       VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
243               << "string FST";
244       return false;
245     }
246     while (fst.Final(s) == Weight::Zero()) {
247       ArcIterator<Fst<Arc>> aiter(fst, s);
248       if (aiter.Done()) {
249         VLOG(2) << "StringPrinter::FstToLabels: String FST traversal does "
250                 << "not reach final state";
251         return false;
252       }
253       const auto &arc = aiter.Value();
254       labels_.push_back(arc.olabel);
255       s = arc.nextstate;
256       if (s == kNoStateId) {
257         VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid state";
258         return false;
259       }
260       aiter.Next();
261       if (!aiter.Done()) {
262         VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
263                 << "outgoing arcs found";
264         return false;
265       }
266     }
267     return true;
268   }
269
270   bool PrintLabel(Label label, std::ostream &ostrm) {
271     if (syms_) {
272       const auto symbol = syms_->Find(label);
273       if (symbol == "") {
274         VLOG(2) << "StringPrinter::PrintLabel: Integer " << label << " is not "
275                 << "mapped to any textual symbol, symbol table = "
276                 << syms_->Name();
277         return false;
278       }
279       ostrm << symbol;
280     } else {
281       ostrm << label;
282     }
283     return true;
284   }
285
286   const StringTokenType token_type_;
287   const SymbolTable *syms_;    // Symbol table (used when token type is symbol).
288   std::vector<Label> labels_;  // Input FST labels.
289
290   StringPrinter(const StringPrinter &) = delete;
291   StringPrinter &operator=(const StringPrinter &) = delete;
292 };
293
294 }  // namespace fst
295
296 #endif  // FST_LIB_STRING_H_