9bc8e43a3034453c906a0eabfbf07a6d65e191da
[platform/upstream/openfst.git] / src / include / fst / string.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Utilities to convert strings into FSTs.
5
6 #ifndef FST_STRING_H_
7 #define FST_STRING_H_
8
9 #include <memory>
10 #include <sstream>
11 #include <string>
12 #include <vector>
13
14 #include <fst/flags.h>
15 #include <fst/log.h>
16
17 #include <fst/compact-fst.h>
18 #include <fst/icu.h>
19 #include <fst/mutable-fst.h>
20 #include <fst/util.h>
21
22
23 DECLARE_string(fst_field_separator);
24
25 namespace fst {
26
27 // This will eventually replace StringCompiler<Arc>::TokenType and
28 // StringPrinter<Arc>::TokenType.
29 enum StringTokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
30
31 namespace internal {
32
33 template <class Label>
34 bool ConvertSymbolToLabel(const char *str, const SymbolTable *syms,
35                           Label unknown_label, bool allow_negative,
36                           Label *output) {
37   int64 n;
38   if (syms) {
39     n = syms->Find(str);
40     if ((n == -1) && (unknown_label != kNoLabel)) n = unknown_label;
41     if (n == -1 || (!allow_negative && n < 0)) {
42       VLOG(1) << "ConvertSymbolToLabel: Symbol \"" << str
43               << "\" is not mapped to any integer label, symbol table = "
44               << syms->Name();
45       return false;
46     }
47   } else {
48     char *p;
49     n = strtoll(str, &p, 10);
50     if (p < str + strlen(str) || (!allow_negative && n < 0)) {
51       VLOG(1) << "ConvertSymbolToLabel: Bad label integer "
52               << "= \"" << str << "\"";
53       return false;
54     }
55   }
56   *output = n;
57   return true;
58 }
59
60 template <class Label>
61 bool ConvertStringToLabels(const string &str, StringTokenType token_type,
62                            const SymbolTable *syms, Label unknown_label,
63                            bool allow_negative, std::vector<Label> *labels) {
64   labels->clear();
65   if (token_type == StringTokenType::BYTE) {
66     for (const char c : str) labels->push_back(c);
67   } else if (token_type == StringTokenType::UTF8) {
68     return UTF8StringToLabels(str, labels);
69   } else {
70     std::unique_ptr<char[]> c_str(new char[str.size() + 1]);
71     str.copy(c_str.get(), str.size());
72     c_str[str.size()] = 0;
73     std::vector<char *> vec;
74     const string separator = "\n" + FLAGS_fst_field_separator;
75     SplitToVector(c_str.get(), separator.c_str(), &vec, true);
76     for (const char *c : vec) {
77       Label label;
78       if (!ConvertSymbolToLabel(c, syms, unknown_label, allow_negative,
79                                 &label)) {
80         return false;
81       }
82       labels->push_back(label);
83     }
84   }
85   return true;
86 }
87
88 }  // namespace internal
89
90 // Functor for compiling a string in an FST.
91 template <class Arc>
92 class StringCompiler {
93  public:
94   using Label = typename Arc::Label;
95   using StateId = typename Arc::StateId;
96   using Weight = typename Arc::Weight;
97
98   explicit StringCompiler(StringTokenType token_type,
99                           const SymbolTable *syms = nullptr,
100                           Label unknown_label = kNoLabel,
101                           bool allow_negative = false)
102       : token_type_(token_type),
103         syms_(syms),
104         unknown_label_(unknown_label),
105         allow_negative_(allow_negative) {}
106
107   // Compiles string into an FST.
108   template <class FST>
109   bool operator()(const string &str, FST *fst) const {
110     std::vector<Label> labels;
111     if (!internal::ConvertStringToLabels(str, token_type_, syms_,
112                                          unknown_label_, allow_negative_,
113                                          &labels)) {
114       return false;
115     }
116     Compile(labels, fst);
117     return true;
118   }
119
120   template <class FST>
121   bool operator()(const string &str, FST *fst, Weight weight) const {
122     std::vector<Label> labels;
123     if (!internal::ConvertStringToLabels(str, token_type_, syms_,
124                                          unknown_label_, allow_negative_,
125                                          &labels)) {
126       return false;
127     }
128     Compile(labels, fst, std::move(weight));
129     return true;
130   }
131
132  private:
133   void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
134                Weight weight = Weight::One()) const {
135     fst->DeleteStates();
136     while (fst->NumStates() <= labels.size()) fst->AddState();
137     for (StateId i = 0; i < labels.size(); ++i) {
138       fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
139     }
140     fst->SetStart(0);
141     fst->SetFinal(labels.size(), std::move(weight));
142   }
143
144   template <class Unsigned>
145   void Compile(const std::vector<Label> &labels,
146                CompactStringFst<Arc, Unsigned> *fst) const {
147     fst->SetCompactElements(labels.begin(), labels.end());
148   }
149
150   template <class Unsigned>
151   void Compile(const std::vector<Label> &labels,
152                CompactWeightedStringFst<Arc, Unsigned> *fst,
153                const Weight &weight = Weight::One()) const {
154     std::vector<std::pair<Label, Weight>> compacts;
155     compacts.reserve(labels.size() + 1);
156     for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
157       compacts.emplace_back(labels[i], Weight::One());
158     }
159     compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
160     fst->SetCompactElements(compacts.begin(), compacts.end());
161   }
162
163   const StringTokenType token_type_;
164   const SymbolTable *syms_;    // Symbol table (used when token type is symbol).
165   const Label unknown_label_;  // Label for token missing from symbol table.
166   const bool allow_negative_;  // Negative labels allowed?
167
168   StringCompiler(const StringCompiler &) = delete;
169   StringCompiler &operator=(const StringCompiler &) = delete;
170 };
171
172 // Functor for printing a string FST as a string.
173 template <class Arc>
174 class StringPrinter {
175  public:
176   using Label = typename Arc::Label;
177   using StateId = typename Arc::StateId;
178   using Weight = typename Arc::Weight;
179
180   explicit StringPrinter(StringTokenType token_type,
181                          const SymbolTable *syms = nullptr)
182       : token_type_(token_type), syms_(syms) {}
183
184   // Converts the FST into a string.
185   bool operator()(const Fst<Arc> &fst, string *result) {
186     if (!FstToLabels(fst)) {
187       VLOG(1) << "StringPrinter::operator(): FST is not a string";
188       return false;
189     }
190     result->clear();
191     if (token_type_ == StringTokenType::SYMBOL) {
192       std::stringstream sstrm;
193       for (size_t i = 0; i < labels_.size(); ++i) {
194         if (i) sstrm << *(FLAGS_fst_field_separator.rbegin());
195         if (!PrintLabel(labels_[i], sstrm)) return false;
196       }
197       *result = sstrm.str();
198     } else if (token_type_ == StringTokenType::BYTE) {
199       result->reserve(labels_.size());
200       for (size_t i = 0; i < labels_.size(); ++i) result->push_back(labels_[i]);
201     } else if (token_type_ == StringTokenType::UTF8) {
202       return LabelsToUTF8String(labels_, result);
203     } else {
204       VLOG(1) << "StringPrinter::operator(): Unknown token type: "
205               << token_type_;
206       return false;
207     }
208     return true;
209   }
210
211  private:
212   bool FstToLabels(const Fst<Arc> &fst) {
213     labels_.clear();
214     auto s = fst.Start();
215     if (s == kNoStateId) {
216       VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
217               << "string FST";
218       return false;
219     }
220     while (fst.Final(s) == Weight::Zero()) {
221       ArcIterator<Fst<Arc>> aiter(fst, s);
222       if (aiter.Done()) {
223         VLOG(2) << "StringPrinter::FstToLabels: String FST traversal does "
224                 << "not reach final state";
225         return false;
226       }
227       const auto &arc = aiter.Value();
228       labels_.push_back(arc.olabel);
229       s = arc.nextstate;
230       if (s == kNoStateId) {
231         VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid state";
232         return false;
233       }
234       aiter.Next();
235       if (!aiter.Done()) {
236         VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
237                 << "outgoing arcs found";
238         return false;
239       }
240     }
241     return true;
242   }
243
244   bool PrintLabel(Label label, std::ostream &ostrm) {
245     if (syms_) {
246       const auto symbol = syms_->Find(label);
247       if (symbol == "") {
248         VLOG(2) << "StringPrinter::PrintLabel: Integer " << label << " is not "
249                 << "mapped to any textual symbol, symbol table = "
250                 << syms_->Name();
251         return false;
252       }
253       ostrm << symbol;
254     } else {
255       ostrm << label;
256     }
257     return true;
258   }
259
260   const StringTokenType token_type_;
261   const SymbolTable *syms_;    // Symbol table (used when token type is symbol).
262   std::vector<Label> labels_;  // Input FST labels.
263
264   StringPrinter(const StringPrinter &) = delete;
265   StringPrinter &operator=(const StringPrinter &) = delete;
266 };
267
268 }  // namespace fst
269
270 #endif  // FST_STRING_H_