1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Utilities to convert strings into FSTs.
6 #ifndef FST_LIB_STRING_H_
7 #define FST_LIB_STRING_H_
16 #include <fst/compact-fst.h>
18 #include <fst/mutable-fst.h>
22 DECLARE_string(fst_field_separator);
// Selects how a string is tokenized into labels (and how labels are rendered
// back to text) by the conversion helpers, StringCompiler and StringPrinter in
// this file. The three modes handled by the visible code are SYMBOL, BYTE and
// UTF8.
// The enum is unscoped; code in this file nevertheless spells the enumerators
// qualified (e.g. StringTokenType::BYTE), which is valid for unscoped enums
// since C++11.
26 // This will eventually replace StringCompiler<Arc>::TokenType and
27 // StringPrinter<Arc>::TokenType.
28 enum StringTokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
// Converts one textual token `str` into an integer label.
//
// Two conversion paths are visible below:
//   * a lookup path, in which a result of -1 is replaced by `unknown_label`
//     when that is not kNoLabel, and a remaining -1 (or a negative value while
//     `allow_negative` is false) is logged as an unmapped symbol;
//   * a numeric path, which parses `str` as a base-10 integer with strtoll and
//     logs "Bad label integer" when parsing stops before the end of `str` or
//     the value is negative while `allow_negative` is false.
//
// NOTE(review): the embedded original line numbers skip 35-38, 43-47 and
// everything after 51, so the end of the parameter list, the declarations of
// `n` and `p`, the condition that chooses between the two paths and the
// return/output statements are not visible in this view. Presumably the lookup
// path is taken when `syms` is non-null -- TODO confirm against the full file.
32 template <class Label>
33 bool ConvertSymbolToLabel(const char *str, const SymbolTable *syms,
34 Label unknown_label, bool allow_negative,
// A value of -1 is treated as "not found" here; fall back to the caller's
// unknown_label when one was supplied.
39 if ((n == -1) && (unknown_label != kNoLabel)) n = unknown_label;
// Still unmapped, or mapped to a negative value that the caller disallows.
40 if (n == -1 || (!allow_negative && n < 0)) {
41 VLOG(1) << "ConvertSymbolToLabel: Symbol \"" << str
42 << "\" is not mapped to any integer label, symbol table = "
// Numeric path: strtoll stores in `p` the position where parsing stopped.
48 n = strtoll(str, &p, 10);
// Reject tokens with trailing non-numeric characters and disallowed negatives.
// NOTE(review): for an empty `str`, strtoll consumes nothing, so `p` equals
// str + strlen(str) and this check alone does not flag it. Out-of-range input
// is not detected in the visible lines either (errno is not inspected) --
// confirm whether callers can pass such tokens.
49 if (p < str + strlen(str) || (!allow_negative && n < 0)) {
50 VLOG(1) << "ConvertSymbolToLabel: Bad label integer "
51 << "= \"" << str << "\"";
// Tokenizes `str` according to `token_type` and appends the resulting labels
// to `*labels`.
//
//   BYTE  - every char of `str` becomes one label.
//   UTF8  - delegates to UTF8StringToLabels and returns its result directly.
//   other - (the remaining branch; presumably SYMBOL) splits `str` into tokens
//           and converts each one with ConvertSymbolToLabel, forwarding
//           `syms`, `unknown_label` and `allow_negative`.
//
// NOTE(review): the embedded original line numbers skip 63, 68, 76, 78-80 and
// everything after 81, so any initial reset of `*labels`, the `else` that
// opens the last branch, the failure handling of the per-token conversion and
// the final return are not visible in this view.
59 template <class Label>
60 bool ConvertStringToLabels(const string &str, StringTokenType token_type,
61 const SymbolTable *syms, Label unknown_label,
62 bool allow_negative, std::vector<Label> *labels) {
64 if (token_type == StringTokenType::BYTE) {
// NOTE(review): the loop variable is plain `char`, whose signedness is
// implementation-defined. Where char is signed and Label is a signed type,
// bytes >= 0x80 would be pushed as negative labels; iterating as
// `unsigned char` would avoid that -- confirm intended behavior.
65 for (const char c : str) labels->push_back(c);
66 } else if (token_type == StringTokenType::UTF8) {
67 return UTF8StringToLabels(str, labels);
// Make a mutable, NUL-terminated copy of `str`; SplitToVector is handed this
// buffer and fills `vec` with char pointers.
69 std::unique_ptr<char[]> c_str(new char[str.size() + 1]);
70 str.copy(c_str.get(), str.size());
71 c_str[str.size()] = 0;
72 std::vector<char *> vec;
// Delimiter set: newline plus whatever the fst_field_separator flag contains.
73 const string separator = "\n" + FLAGS_fst_field_separator;
// NOTE(review): the meaning of the trailing `true` is defined by SplitToVector,
// which is not shown here (presumably "omit empty tokens") -- TODO confirm.
74 SplitToVector(c_str.get(), separator.c_str(), &vec, true);
75 for (const char *c : vec) {
77 if (!ConvertSymbolToLabel(c, syms, unknown_label, allow_negative,
81 labels->push_back(label);
87 } // namespace internal
89 // Functor for compiling a string in an FST.
//
// A StringCompiler is configured once with a token type, an optional symbol
// table, an unknown-symbol label and a negative-label policy; operator() then
// tokenizes a string via internal::ConvertStringToLabels and hands the labels
// to one of the Compile() overloads below.
//
// NOTE(review): the embedded original line numbers skip several lines of this
// class (e.g. 90, 92, 102, 107, 112-114, 116-119, 134, 138-139, 169), so the
// template headers, access specifiers, one member initializer, the return
// statements and the closing braces are not visible in this view.
91 class StringCompiler {
93 using Label = typename Arc::Label;
94 using StateId = typename Arc::StateId;
95 using Weight = typename Arc::Weight;
// Only `token_type` is required; by default there is no symbol table, no
// unknown-label substitution (kNoLabel) and negative labels are rejected.
97 explicit StringCompiler(StringTokenType token_type,
98 const SymbolTable *syms = nullptr,
99 Label unknown_label = kNoLabel,
100 bool allow_negative = false)
101 : token_type_(token_type),
103 unknown_label_(unknown_label),
104 allow_negative_(allow_negative) {}
106 // Compiles string into an FST.
// Unweighted form: converts `str` to labels and calls Compile(labels, fst).
// `FST` is presumably a template parameter declared on the missing line 107.
108 bool operator()(const string &str, FST *fst) const {
109 std::vector<Label> labels;
110 if (!internal::ConvertStringToLabels(str, token_type_, syms_,
111 unknown_label_, allow_negative_,
115 Compile(labels, fst);
// Weighted form: same as above, but `weight` (taken by value and moved) is
// forwarded to Compile.
120 bool operator()(const string &str, FST *fst, Weight weight) const {
121 std::vector<Label> labels;
122 if (!internal::ConvertStringToLabels(str, token_type_, syms_,
123 unknown_label_, allow_negative_,
127 Compile(labels, fst, std::move(weight));
// Builds a linear chain in a mutable FST: states 0..labels.size(), one arc
// from state i to state i + 1 carrying labels[i] in both label positions with
// weight One, and `weight` as the final weight of state labels.size().
// NOTE(review): the Arc constructor argument order is assumed to be
// (ilabel, olabel, weight, nextstate) -- confirm against the Arc type.
132 void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
133 Weight weight = Weight::One()) const {
// Only adds states until there are more than labels.size() of them; states
// already present in `fst` are not removed by the visible lines.
// NOTE(review): confirm the missing line 134 resets `fst` first.
135 while (fst->NumStates() <= labels.size()) fst->AddState();
// NOTE(review): `i < labels.size()` compares StateId with size_t; if StateId
// is signed this is a signed/unsigned comparison.
136 for (StateId i = 0; i < labels.size(); ++i) {
137 fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
140 fst->SetFinal(labels.size(), std::move(weight));
// Compact unweighted string FST: the label sequence itself is the compact
// representation.
143 template <class Unsigned>
144 void Compile(const std::vector<Label> &labels,
145 CompactStringFst<Arc, Unsigned> *fst) const {
146 fst->SetCompactElements(labels.begin(), labels.end());
// Compact weighted string FST: builds (label, weight) pairs. Every label but
// the last is paired with Weight::One(); the last label -- or kNoLabel when
// `labels` is empty -- is paired with `weight`.
149 template <class Unsigned>
150 void Compile(const std::vector<Label> &labels,
151 CompactWeightedStringFst<Arc, Unsigned> *fst,
152 const Weight &weight = Weight::One()) const {
153 std::vector<std::pair<Label, Weight>> compacts;
154 compacts.reserve(labels.size() + 1);
// The size is cast to StateId before subtracting 1 so that an empty `labels`
// gives a bound of -1 rather than wrapping around (assumes StateId is a signed
// type -- TODO confirm).
155 for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
156 compacts.emplace_back(labels[i], Weight::One());
158 compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
159 fst->SetCompactElements(compacts.begin(), compacts.end());
162 const StringTokenType token_type_;
163 const SymbolTable *syms_; // Symbol table (used when token type is symbol).
164 const Label unknown_label_; // Label for token missing from symbol table.
165 const bool allow_negative_; // Negative labels allowed?
// Copy construction and copy assignment are disabled.
167 StringCompiler(const StringCompiler &) = delete;
168 StringCompiler &operator=(const StringCompiler &) = delete;
171 // Functor for printing a string FST as a string.
//
// operator() first collects labels from the FST into the labels_ member via
// FstToLabels() and then renders them according to the token type, so the
// functor is stateful and operator() is non-const.
//
// NOTE(review): the embedded original line numbers skip several lines of this
// class (e.g. 172, 174, 187-189, 195, 202, 204-210, 212, 216-218, 221,
// 224-225, 228, 231-234, 237-242, 244, 246, 249-258), so the template header,
// access specifiers, several conditions, the return statements and the closing
// braces are not visible in this view.
173 class StringPrinter {
175 using Label = typename Arc::Label;
176 using StateId = typename Arc::StateId;
177 using Weight = typename Arc::Weight;
// The symbol table is optional and defaults to nullptr.
179 explicit StringPrinter(StringTokenType token_type,
180 const SymbolTable *syms = nullptr)
181 : token_type_(token_type), syms_(syms) {}
183 // Converts the FST into a string.
184 bool operator()(const Fst<Arc> &fst, string *result) {
185 if (!FstToLabels(fst)) {
186 VLOG(1) << "StringPrinter::operator(): FST is not a string";
190 if (token_type_ == StringTokenType::SYMBOL) {
// SYMBOL: print each label through PrintLabel, writing the last character of
// the fst_field_separator flag between consecutive labels (not before the
// first one), then assign the accumulated text to *result.
191 std::stringstream sstrm;
192 for (size_t i = 0; i < labels_.size(); ++i) {
// NOTE(review): dereferencing rbegin() requires a non-empty separator flag;
// no emptiness check is visible here -- confirm it is validated elsewhere.
193 if (i) sstrm << *(FLAGS_fst_field_separator.rbegin());
194 if (!PrintLabel(labels_[i], sstrm)) return false;
196 *result = sstrm.str();
197 } else if (token_type_ == StringTokenType::BYTE) {
// BYTE: each label is appended to *result as a single char.
// NOTE(review): this branch appends to *result whereas the SYMBOL branch
// assigns; no clear() of *result is visible, so a non-empty *result on entry
// would be preserved here -- confirm against the missing lines.
198 result->reserve(labels_.size());
199 for (size_t i = 0; i < labels_.size(); ++i) result->push_back(labels_[i]);
200 } else if (token_type_ == StringTokenType::UTF8) {
// UTF8: delegate the encoding and return its result directly.
201 return LabelsToUTF8String(labels_, result);
203 VLOG(1) << "StringPrinter::operator(): Unknown token type: "
// Walks the FST from its start state and appends each traversed arc's output
// label (arc.olabel) to labels_. The loop continues while the current state's
// final weight equals Weight::Zero(). The log messages show which situations
// are treated as errors: no start state, a traversal that does not reach a
// final state, a transition to an invalid state, and a state with multiple
// outgoing arcs.
// NOTE(review): the conditions guarding some of these messages, the update of
// `s` and any reset of labels_ are on lines not visible in this view.
211 bool FstToLabels(const Fst<Arc> &fst) {
213 auto s = fst.Start();
214 if (s == kNoStateId) {
215 VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
219 while (fst.Final(s) == Weight::Zero()) {
220 ArcIterator<Fst<Arc>> aiter(fst, s);
222 VLOG(2) << "StringPrinter::FstToLabels: String FST traversal does "
223 << "not reach final state";
226 const auto &arc = aiter.Value();
227 labels_.push_back(arc.olabel);
229 if (s == kNoStateId) {
230 VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid state";
235 VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
236 << "outgoing arcs found";
// Writes the textual form of `label` to `ostrm`; the visible lines look the
// label up in syms_ and log when it has no textual symbol.
// NOTE(review): syms_ defaults to nullptr in the constructor and is
// dereferenced below; presumably the missing line 244 guards this -- TODO
// confirm.
243 bool PrintLabel(Label label, std::ostream &ostrm) {
245 const auto symbol = syms_->Find(label);
247 VLOG(2) << "StringPrinter::PrintLabel: Integer " << label << " is not "
248 << "mapped to any textual symbol, symbol table = "
259 const StringTokenType token_type_;
260 const SymbolTable *syms_; // Symbol table (used when token type is symbol).
// Filled by FstToLabels() from arc.olabel and read by operator().
261 std::vector<Label> labels_; // Input FST labels.
// Copy construction and copy assignment are disabled.
263 StringPrinter(const StringPrinter &) = delete;
264 StringPrinter &operator=(const StringPrinter &) = delete;
269 #endif // FST_LIB_STRING_H_