Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / string.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Utilities to convert strings into FSTs.
5
6 #ifndef FST_LIB_STRING_H_
7 #define FST_LIB_STRING_H_
8
9 #include <memory>
10 #include <sstream>
11 #include <string>
12 #include <vector>
13
14 #include <fst/log.h>
15
16 #include <fst/compact-fst.h>
17 #include <fst/icu.h>
18 #include <fst/mutable-fst.h>
19 #include <fst/util.h>
20
21
22 DECLARE_string(fst_field_separator);
23
24 namespace fst {
25
26 // This will eventually replace StringCompiler<Arc>::TokenType and
27 // StringPrinter<Arc>::TokenType.
28 enum StringTokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
29
30 namespace internal {
31
32 template <class Label>
33 bool ConvertSymbolToLabel(const char *str, const SymbolTable *syms,
34                           Label unknown_label, bool allow_negative,
35                           Label *output) {
36   int64 n;
37   if (syms) {
38     n = syms->Find(str);
39     if ((n == -1) && (unknown_label != kNoLabel)) n = unknown_label;
40     if (n == -1 || (!allow_negative && n < 0)) {
41       VLOG(1) << "ConvertSymbolToLabel: Symbol \"" << str
42               << "\" is not mapped to any integer label, symbol table = "
43               << syms->Name();
44       return false;
45     }
46   } else {
47     char *p;
48     n = strtoll(str, &p, 10);
49     if (p < str + strlen(str) || (!allow_negative && n < 0)) {
50       VLOG(1) << "ConvertSymbolToLabel: Bad label integer "
51               << "= \"" << str << "\"";
52       return false;
53     }
54   }
55   *output = n;
56   return true;
57 }
58
59 template <class Label>
60 bool ConvertStringToLabels(const string &str, StringTokenType token_type,
61                            const SymbolTable *syms, Label unknown_label,
62                            bool allow_negative, std::vector<Label> *labels) {
63   labels->clear();
64   if (token_type == StringTokenType::BYTE) {
65     for (const char c : str) labels->push_back(c);
66   } else if (token_type == StringTokenType::UTF8) {
67     return UTF8StringToLabels(str, labels);
68   } else {
69     std::unique_ptr<char[]> c_str(new char[str.size() + 1]);
70     str.copy(c_str.get(), str.size());
71     c_str[str.size()] = 0;
72     std::vector<char *> vec;
73     const string separator = "\n" + FLAGS_fst_field_separator;
74     SplitToVector(c_str.get(), separator.c_str(), &vec, true);
75     for (const char *c : vec) {
76       Label label;
77       if (!ConvertSymbolToLabel(c, syms, unknown_label, allow_negative,
78                                 &label)) {
79         return false;
80       }
81       labels->push_back(label);
82     }
83   }
84   return true;
85 }
86
87 }  // namespace internal
88
89 // Functor for compiling a string in an FST.
90 template <class Arc>
91 class StringCompiler {
92  public:
93   using Label = typename Arc::Label;
94   using StateId = typename Arc::StateId;
95   using Weight = typename Arc::Weight;
96
97   explicit StringCompiler(StringTokenType token_type,
98                           const SymbolTable *syms = nullptr,
99                           Label unknown_label = kNoLabel,
100                           bool allow_negative = false)
101       : token_type_(token_type),
102         syms_(syms),
103         unknown_label_(unknown_label),
104         allow_negative_(allow_negative) {}
105
106   // Compiles string into an FST.
107   template <class FST>
108   bool operator()(const string &str, FST *fst) const {
109     std::vector<Label> labels;
110     if (!internal::ConvertStringToLabels(str, token_type_, syms_,
111                                          unknown_label_, allow_negative_,
112                                          &labels)) {
113       return false;
114     }
115     Compile(labels, fst);
116     return true;
117   }
118
119   template <class FST>
120   bool operator()(const string &str, FST *fst, Weight weight) const {
121     std::vector<Label> labels;
122     if (!internal::ConvertStringToLabels(str, token_type_, syms_,
123                                          unknown_label_, allow_negative_,
124                                          &labels)) {
125       return false;
126     }
127     Compile(labels, fst, std::move(weight));
128     return true;
129   }
130
131  private:
132   void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
133                Weight weight = Weight::One()) const {
134     fst->DeleteStates();
135     while (fst->NumStates() <= labels.size()) fst->AddState();
136     for (StateId i = 0; i < labels.size(); ++i) {
137       fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
138     }
139     fst->SetStart(0);
140     fst->SetFinal(labels.size(), std::move(weight));
141   }
142
143   template <class Unsigned>
144   void Compile(const std::vector<Label> &labels,
145                CompactStringFst<Arc, Unsigned> *fst) const {
146     fst->SetCompactElements(labels.begin(), labels.end());
147   }
148
149   template <class Unsigned>
150   void Compile(const std::vector<Label> &labels,
151                CompactWeightedStringFst<Arc, Unsigned> *fst,
152                const Weight &weight = Weight::One()) const {
153     std::vector<std::pair<Label, Weight>> compacts;
154     compacts.reserve(labels.size() + 1);
155     for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
156       compacts.emplace_back(labels[i], Weight::One());
157     }
158     compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
159     fst->SetCompactElements(compacts.begin(), compacts.end());
160   }
161
162   const StringTokenType token_type_;
163   const SymbolTable *syms_;    // Symbol table (used when token type is symbol).
164   const Label unknown_label_;  // Label for token missing from symbol table.
165   const bool allow_negative_;  // Negative labels allowed?
166
167   StringCompiler(const StringCompiler &) = delete;
168   StringCompiler &operator=(const StringCompiler &) = delete;
169 };
170
171 // Functor for printing a string FST as a string.
172 template <class Arc>
173 class StringPrinter {
174  public:
175   using Label = typename Arc::Label;
176   using StateId = typename Arc::StateId;
177   using Weight = typename Arc::Weight;
178
179   explicit StringPrinter(StringTokenType token_type,
180                          const SymbolTable *syms = nullptr)
181       : token_type_(token_type), syms_(syms) {}
182
183   // Converts the FST into a string.
184   bool operator()(const Fst<Arc> &fst, string *result) {
185     if (!FstToLabels(fst)) {
186       VLOG(1) << "StringPrinter::operator(): FST is not a string";
187       return false;
188     }
189     result->clear();
190     if (token_type_ == StringTokenType::SYMBOL) {
191       std::stringstream sstrm;
192       for (size_t i = 0; i < labels_.size(); ++i) {
193         if (i) sstrm << *(FLAGS_fst_field_separator.rbegin());
194         if (!PrintLabel(labels_[i], sstrm)) return false;
195       }
196       *result = sstrm.str();
197     } else if (token_type_ == StringTokenType::BYTE) {
198       result->reserve(labels_.size());
199       for (size_t i = 0; i < labels_.size(); ++i) result->push_back(labels_[i]);
200     } else if (token_type_ == StringTokenType::UTF8) {
201       return LabelsToUTF8String(labels_, result);
202     } else {
203       VLOG(1) << "StringPrinter::operator(): Unknown token type: "
204               << token_type_;
205       return false;
206     }
207     return true;
208   }
209
210  private:
211   bool FstToLabels(const Fst<Arc> &fst) {
212     labels_.clear();
213     auto s = fst.Start();
214     if (s == kNoStateId) {
215       VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
216               << "string FST";
217       return false;
218     }
219     while (fst.Final(s) == Weight::Zero()) {
220       ArcIterator<Fst<Arc>> aiter(fst, s);
221       if (aiter.Done()) {
222         VLOG(2) << "StringPrinter::FstToLabels: String FST traversal does "
223                 << "not reach final state";
224         return false;
225       }
226       const auto &arc = aiter.Value();
227       labels_.push_back(arc.olabel);
228       s = arc.nextstate;
229       if (s == kNoStateId) {
230         VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid state";
231         return false;
232       }
233       aiter.Next();
234       if (!aiter.Done()) {
235         VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
236                 << "outgoing arcs found";
237         return false;
238       }
239     }
240     return true;
241   }
242
243   bool PrintLabel(Label label, std::ostream &ostrm) {
244     if (syms_) {
245       const auto symbol = syms_->Find(label);
246       if (symbol == "") {
247         VLOG(2) << "StringPrinter::PrintLabel: Integer " << label << " is not "
248                 << "mapped to any textual symbol, symbol table = "
249                 << syms_->Name();
250         return false;
251       }
252       ostrm << symbol;
253     } else {
254       ostrm << label;
255     }
256     return true;
257   }
258
259   const StringTokenType token_type_;
260   const SymbolTable *syms_;    // Symbol table (used when token type is symbol).
261   std::vector<Label> labels_;  // Input FST labels.
262
263   StringPrinter(const StringPrinter &) = delete;
264   StringPrinter &operator=(const StringPrinter &) = delete;
265 };
266
267 }  // namespace fst
268
269 #endif  // FST_LIB_STRING_H_