1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Utilities to convert strings into FSTs.
6 #ifndef FST_LIB_STRING_H_
7 #define FST_LIB_STRING_H_
15 #include <fst/compact-fst.h>
17 #include <fst/mutable-fst.h>
21 DECLARE_string(fst_field_separator);
25 // This will eventually replace StringCompiler<Arc>::TokenType and
26 // StringPrinter<Arc>::TokenType.
//
// Selects how a string is split into labels (and how labels are rendered back
// into a string). In this file:
//   * BYTE  - each char of the string becomes one label.
//   * UTF8  - conversion is delegated to UTF8StringToLabels /
//             LabelsToUTF8String.
//   * SYMBOL - the string is split on separator characters and each token is
//             converted individually (symbol table lookup or integer parse).
// The deprecated StringCompiler/StringPrinter constructors static_cast their
// nested TokenType to this enum, so the numeric values presumably have to stay
// in sync with those nested enums (their enumerators are not visible here).
27 enum StringTokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
// Converts a single token `str` into an integer label.
//
// NOTE(review): this extract omits several lines of the definition (the end of
// the parameter list, the declaration of `n` and `p`, the branch that selects
// between the two strategies below, and the return statements). The comments
// here describe only the statements that are visible.
//
//   str            - NUL-terminated token text.
//   syms           - symbol table; the first visible check reports failure in
//                    terms of this table.
//   unknown_label  - substituted when the lookup yields -1, unless it is
//                    kNoLabel.
//   allow_negative - when false, a negative result is rejected.
31 template <class Label>
32 bool ConvertSymbolToLabel(const char *str, const SymbolTable *syms,
33 Label unknown_label, bool allow_negative,
// -1 is treated as "not found"; fall back to the caller-supplied unknown label
// when one was given.
38 if ((n == -1) && (unknown_label != kNoLabel)) n = unknown_label;
// Still unmapped, or negative while negatives are disallowed: log the failure.
39 if (n == -1 || (!allow_negative && n < 0)) {
40 VLOG(1) << "ConvertSymbolToLabel: Symbol \"" << str
41 << "\" is not mapped to any integer label, symbol table = "
// Integer path: parse `str` as a base-10 integer; strtoll stores in `p` the
// address of the first character it did not consume.
47 n = strtoll(str, &p, 10);
// Reject the token if parsing stopped before the end of `str` (trailing
// garbage) or if the value is negative while negatives are disallowed.
// NOTE(review): for an empty `str`, strtoll performs no conversion, returns 0
// and leaves p == str, which equals str + strlen(str); this check therefore
// does not flag an empty token. No ERANGE/overflow check is visible either.
48 if (p < str + strlen(str) || (!allow_negative && n < 0)) {
49 VLOG(1) << "ConvertSymbolToLabel: Bad label integer "
50 << "= \"" << str << "\"";
// Converts `str` into a sequence of labels stored in `*labels`, according to
// `token_type`.
//
// NOTE(review): this extract omits several lines of the definition (anything
// done to `*labels` before the first branch, the final `else`, the declaration
// of `label`, the failure handling inside the loop, buffer clean-up, and the
// final return). The comments here describe only the visible statements.
//
//   str            - text to convert.
//   token_type     - BYTE, UTF8, or (in the remaining branch) symbol tokens.
//   syms, unknown_label, allow_negative
//                  - forwarded unchanged to ConvertSymbolToLabel for each
//                    token in the symbol branch.
//   labels         - output vector; labels are appended with push_back.
58 template <class Label>
59 bool ConvertStringToLabels(const string &str, StringTokenType token_type,
60 const SymbolTable *syms, Label unknown_label,
61 bool allow_negative, std::vector<Label> *labels) {
63 if (token_type == StringTokenType::BYTE) {
// One label per char. NOTE(review): `c` is a plain `char`; on platforms where
// char is signed, bytes >= 0x80 convert to negative labels - confirm that is
// intended.
64 for (const char c : str) labels->push_back(c);
65 } else if (token_type == StringTokenType::UTF8) {
// UTF-8 decoding is delegated entirely; its result is returned directly.
66 return UTF8StringToLabels(str, labels);
// Remaining branch: make a mutable, NUL-terminated copy of `str`, because
// SplitToVector is handed a non-const char buffer and yields `char *` tokens
// (presumably it splits in place - confirm against its definition).
// NOTE(review): the buffer comes from a raw new[]; the matching delete[] is not
// visible in this extract - verify that every exit path releases it.
68 auto *c_str = new char[str.size() + 1];
69 str.copy(c_str, str.size());
70 c_str[str.size()] = 0;
71 std::vector<char *> vec;
// Delimiters are newline plus the contents of the fst_field_separator flag.
72 const string separator = "\n" + FLAGS_fst_field_separator;
// The meaning of the trailing `true` argument is not visible here.
73 SplitToVector(c_str, separator.c_str(), &vec, true);
// Convert each token in order and append the resulting label.
74 for (const char *c : vec) {
76 if (!ConvertSymbolToLabel(c, syms, unknown_label, allow_negative,
81 labels->push_back(label);
88 } // namespace internal
90 // Functor for compiling a string in an FST.
//
// Holds the tokenization settings (token type, symbol table, unknown label,
// negative-label policy) and applies them via
// internal::ConvertStringToLabels, then builds the FST with one of the
// Compile overloads below.
//
// NOTE(review): this extract omits a number of lines of the class (the
// template header that introduces `Arc`, access specifiers, the template
// header that introduces `FST`, the enumerators of the deprecated TokenType,
// the ends of several statements, and return statements). Comments describe
// only what is visible.
92 class StringCompiler {
94 using Label = typename Arc::Label;
95 using StateId = typename Arc::StateId;
96 using Weight = typename Arc::Weight;
// Stores the settings. Defaults: no symbol table, no unknown label (kNoLabel),
// negative labels not allowed.
98 explicit StringCompiler(StringTokenType token_type,
99 const SymbolTable *syms = nullptr,
100 Label unknown_label = kNoLabel,
101 bool allow_negative = false)
102 : token_type_(token_type),
104 unknown_label_(unknown_label),
105 allow_negative_(allow_negative) {}
// Deprecated nested token-type enum; use fst::StringTokenType instead.
107 enum OPENFST_DEPRECATED("Use fst::StringTokenType") TokenType {
// Deprecated constructor: converts the nested TokenType to StringTokenType with
// a static_cast and delegates to the constructor above.
113 OPENFST_DEPRECATED("Use fst::StringTokenType")
114 explicit StringCompiler(TokenType token_type,
115 const SymbolTable *syms = nullptr,
116 Label unknown_label = kNoLabel,
117 bool allow_negative = false)
118 : StringCompiler(static_cast<StringTokenType>(token_type), syms,
119 unknown_label, allow_negative) {}
121 // Compiles string into an FST.
// Converts `str` to labels using the stored settings, then calls
// Compile(labels, fst); overload resolution on the type of `fst` picks the
// Compile variant.
124 bool operator()(const string &str, FST *fst) const {
125 std::vector<Label> labels;
126 if (!internal::ConvertStringToLabels(str, token_type_, syms_,
127 unknown_label_, allow_negative_,
131 Compile(labels, fst);
// Same as above, but `weight` is moved into Compile as its third argument.
136 bool operator()(const string &str, FST *fst, Weight weight) const {
137 std::vector<Label> labels;
138 if (!internal::ConvertStringToLabels(str, token_type_, syms_,
139 unknown_label_, allow_negative_,
143 Compile(labels, fst, std::move(weight));
// Builds a linear chain in a mutable FST: state i has one arc to state i + 1
// that carries labels[i] in both label arguments with weight Weight::One(),
// and state labels.size() is made final with `weight`.
148 void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
149 Weight weight = Weight::One()) const {
// Ensure states 0..labels.size() exist (labels.size() + 1 states in total).
151 while (fst->NumStates() <= labels.size()) fst->AddState();
// NOTE(review): `i` is a StateId while labels.size() is a size_t; if StateId is
// signed this is a mixed-sign comparison (the weighted-compact overload below
// casts explicitly).
152 for (StateId i = 0; i < labels.size(); ++i) {
153 fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
156 fst->SetFinal(labels.size(), std::move(weight));
// Compact unweighted string FST: the label sequence itself is handed over as
// the compact elements.
159 template <class Unsigned>
160 void Compile(const std::vector<Label> &labels,
161 CompactStringFst<Arc, Unsigned> *fst) const {
162 fst->SetCompactElements(labels.begin(), labels.end());
// Compact weighted string FST: the compact elements are (label, weight) pairs.
// Every label except the last is paired with Weight::One(); the last label (or
// kNoLabel when `labels` is empty) is paired with `weight`.
165 template <class Unsigned>
166 void Compile(const std::vector<Label> &labels,
167 CompactWeightedStringFst<Arc, Unsigned> *fst,
168 const Weight &weight = Weight::One()) const {
169 std::vector<std::pair<Label, Weight>> compacts;
170 compacts.reserve(labels.size() + 1);
// The size is cast to StateId before subtracting 1 so that an empty `labels`
// does not wrap around as an unsigned value (assumes StateId is signed -
// TODO confirm).
171 for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
172 compacts.emplace_back(labels[i], Weight::One());
174 compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
175 fst->SetCompactElements(compacts.begin(), compacts.end());
// Settings captured at construction; all are const.
178 const StringTokenType token_type_;
179 const SymbolTable *syms_; // Symbol table (used when token type is symbol).
180 const Label unknown_label_; // Label for token missing from symbol table.
181 const bool allow_negative_; // Negative labels allowed?
// Copying is disabled.
183 StringCompiler(const StringCompiler &) = delete;
184 StringCompiler &operator=(const StringCompiler &) = delete;
187 // Functor for printing a string FST as a string.
//
// operator() first collects the FST's labels into the member `labels_`
// (FstToLabels), then renders them according to the stored token type. Because
// `labels_` is a data member written during printing, operator() is non-const.
//
// NOTE(review): this extract omits a number of lines of the class (the
// template header that introduces `Arc`, access specifiers, the enumerators of
// the deprecated TokenType, several conditions, assignments, closing braces
// and return statements). Comments describe only what is visible.
189 class StringPrinter {
191 using Label = typename Arc::Label;
192 using StateId = typename Arc::StateId;
193 using Weight = typename Arc::Weight;
// Stores the token type and an optional symbol table (defaults to nullptr).
195 explicit StringPrinter(StringTokenType token_type,
196 const SymbolTable *syms = nullptr)
197 : token_type_(token_type), syms_(syms) {}
// Deprecated nested token-type enum; use fst::StringTokenType instead.
199 enum OPENFST_DEPRECATED("Use fst::StringTokenType") TokenType {
// Deprecated constructor: converts the nested TokenType to StringTokenType with
// a static_cast and delegates to the constructor above.
205 OPENFST_DEPRECATED("Use fst::StringTokenType")
206 explicit StringPrinter(TokenType token_type,
207 const SymbolTable *syms = nullptr)
208 : StringPrinter(static_cast<StringTokenType>(token_type), syms) {}
210 // Converts the FST into a string.
211 bool operator()(const Fst<Arc> &fst, string *result) {
// Extract the label sequence; a failure is logged as "FST is not a string".
212 if (!FstToLabels(fst)) {
213 VLOG(1) << "StringPrinter::operator(): FST is not a string";
217 if (token_type_ == StringTokenType::SYMBOL) {
// Symbol mode: print each label via PrintLabel, separated by the LAST
// character of the fst_field_separator flag.
// NOTE(review): dereferencing rbegin() is undefined if the flag string is
// empty - confirm the flag can never be empty here.
218 std::stringstream sstrm;
219 for (size_t i = 0; i < labels_.size(); ++i) {
220 if (i) sstrm << *(FLAGS_fst_field_separator.rbegin());
221 if (!PrintLabel(labels_[i], sstrm)) return false;
223 *result = sstrm.str();
224 } else if (token_type_ == StringTokenType::BYTE) {
// Byte mode: append each label to *result as one character.
// NOTE(review): no clearing of *result is visible in this extract before the
// appends - confirm against the full source.
225 result->reserve(labels_.size());
226 for (size_t i = 0; i < labels_.size(); ++i) result->push_back(labels_[i]);
227 } else if (token_type_ == StringTokenType::UTF8) {
// UTF-8 encoding is delegated entirely; its result is returned directly.
228 return LabelsToUTF8String(labels_, result);
230 VLOG(1) << "StringPrinter::operator(): Unknown token type: "
// Walks the FST from its start state and appends the output label of each arc
// taken to `labels_`. The visible log messages show the rejected cases: no
// valid start state, a traversal that does not reach a final state, a
// transition to an invalid state, and a state with multiple outgoing arcs.
// NOTE(review): whether `labels_` is cleared on entry is not visible here.
238 bool FstToLabels(const Fst<Arc> &fst) {
240 auto s = fst.Start();
241 if (s == kNoStateId) {
242 VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
// Continue while the current state has final weight Weight::Zero().
246 while (fst.Final(s) == Weight::Zero()) {
247 ArcIterator<Fst<Arc>> aiter(fst, s);
249 VLOG(2) << "StringPrinter::FstToLabels: String FST traversal does "
250 << "not reach final state";
// Record the output-side label of the current arc.
253 const auto &arc = aiter.Value();
254 labels_.push_back(arc.olabel);
256 if (s == kNoStateId) {
257 VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid state";
262 VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
263 << "outgoing arcs found";
// Writes one label to `ostrm`. The visible lines look the label up in `syms_`
// and log when it has no textual symbol; the surrounding conditions and the
// actual write are not visible in this extract.
270 bool PrintLabel(Label label, std::ostream &ostrm) {
272 const auto symbol = syms_->Find(label);
274 VLOG(2) << "StringPrinter::PrintLabel: Integer " << label << " is not "
275 << "mapped to any textual symbol, symbol table = "
286 const StringTokenType token_type_;
287 const SymbolTable *syms_; // Symbol table (used when token type is symbol).
288 std::vector<Label> labels_; // Labels read from the FST (FstToLabels stores arc.olabel).
// Copying is disabled.
290 StringPrinter(const StringPrinter &) = delete;
291 StringPrinter &operator=(const StringPrinter &) = delete;
296 #endif // FST_LIB_STRING_H_