1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Utilities to convert strings into FSTs.
14 #include <fst/flags.h>
17 #include <fst/compact-fst.h>
19 #include <fst/mutable-fst.h>
// Declares the string flag fst_field_separator (presumably DEFINE_string'd in
// another translation unit). It is read in this header as
// FLAGS_fst_field_separator: ConvertStringToLabels splits SYMBOL-type input on
// "\n" plus the flag's characters, and StringPrinter separates printed symbols
// with the flag's last character.
23 DECLARE_string(fst_field_separator);
27 // This will eventually replace StringCompiler<Arc>::TokenType and
28 // StringPrinter<Arc>::TokenType.
//
// Selects how a string is mapped to labels (StringCompiler) and how labels are
// mapped back to a string (StringPrinter):
//   SYMBOL - the string is split into fields, each converted by
//            ConvertSymbolToLabel (symbol-table lookup or strtoll); printing
//            joins the PrintLabel output of each label with a separator char.
//   BYTE   - one label per char of the string; one char per label when
//            printing.
//   UTF8   - delegated to UTF8StringToLabels / LabelsToUTF8String.
// This is an unscoped enum: the code refers to StringTokenType::SYMBOL etc.,
// but the enumerators are also visible unqualified in the enclosing scope.
29 enum StringTokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
// Converts a single token `str` into a label. Visible behavior:
//  - Symbol-table path: a value of -1 in `n` (not found) falls back to
//    `unknown_label` unless that is kNoLabel; a value still equal to -1, or a
//    negative value when !allow_negative, is logged at VLOG(1) and rejected.
//  - Integer path: `str` is parsed with strtoll in base 10 and rejected
//    (VLOG(1)) if characters remain after the number, or if the value is
//    negative when !allow_negative.
// Returns bool; presumably true on success with the label stored through an
// output parameter.
// NOTE(review): this excerpt is missing source lines. Missing: the final
// parameter, the declaration of `n`, the SymbolTable lookup, the branch that
// chooses the symbol-table vs. integer path, and the assignment/returns.
// Restore them from upstream rather than inferring them from what is shown.
// NOTE(review): an empty `str` leaves p == str + strlen(str), so it passes the
// trailing-character check and parses as 0. Also, strtoll returns long long;
// confirm the (hidden) store into Label range-checks it.
33 template <class Label>
34 bool ConvertSymbolToLabel(const char *str, const SymbolTable *syms,
35 Label unknown_label, bool allow_negative,
// A miss (-1) maps to unknown_label only when the caller supplied one.
40 if ((n == -1) && (unknown_label != kNoLabel)) n = unknown_label;
41 if (n == -1 || (!allow_negative && n < 0)) {
42 VLOG(1) << "ConvertSymbolToLabel: Symbol \"" << str
43 << "\" is not mapped to any integer label, symbol table = "
49 n = strtoll(str, &p, 10);
// strtoll leaves `p` at the first unparsed character, so p before the
// terminator means trailing garbage.
50 if (p < str + strlen(str) || (!allow_negative && n < 0)) {
51 VLOG(1) << "ConvertSymbolToLabel: Bad label integer "
52 << "= \"" << str << "\"";
// Tokenizes `str` according to `token_type` and pushes the resulting labels
// onto `*labels`:
//  - BYTE:   one label per char of `str`.
//  - UTF8:   returns UTF8StringToLabels(str, labels) directly.
//  - SYMBOL: splits a NUL-terminated copy of `str` on "\n" plus the characters
//            of FLAGS_fst_field_separator. Each field is converted with
//            ConvertSymbolToLabel using syms, unknown_label and allow_negative.
// NOTE(review): several source lines are missing from this excerpt. Missing:
// the `else` that opens the SYMBOL branch, the `label` declaration, the
// failure handling after ConvertSymbolToLabel, the final return, and whether
// `labels` is cleared first. Restore from upstream.
// NOTE(review): in the BYTE branch `c` is a plain `char`. Where char is
// signed, bytes >= 0x80 become negative labels (0xFF -> -1, the value
// kNoLabel has in OpenFST - confirm). `allow_negative` is not consulted on
// this path. Converting through `unsigned char` would avoid this; check
// callers and StringPrinter's BYTE round-trip before changing it.
60 template <class Label>
61 bool ConvertStringToLabels(const string &str, StringTokenType token_type,
62 const SymbolTable *syms, Label unknown_label,
63 bool allow_negative, std::vector<Label> *labels) {
65 if (token_type == StringTokenType::BYTE) {
66 for (const char c : str) labels->push_back(c);
67 } else if (token_type == StringTokenType::UTF8) {
68 return UTF8StringToLabels(str, labels);
// SplitToVector is handed a writable char buffer and yields char* fields, so
// copy `str` into an owned, NUL-terminated buffer first.
70 std::unique_ptr<char[]> c_str(new char[str.size() + 1]);
71 str.copy(c_str.get(), str.size());
72 c_str[str.size()] = 0;
73 std::vector<char *> vec;
74 const string separator = "\n" + FLAGS_fst_field_separator;
// NOTE(review): the trailing `true` presumably drops empty fields; confirm in
// SplitToVector's declaration.
75 SplitToVector(c_str.get(), separator.c_str(), &vec, true);
76 for (const char *c : vec) {
78 if (!ConvertSymbolToLabel(c, syms, unknown_label, allow_negative,
82 labels->push_back(label);
88 } // namespace internal
90 // Functor for compiling a string into an FST.
// The string is converted to labels according to a StringTokenType (via
// internal::ConvertStringToLabels). For a MutableFst target the labels are
// written as a linear chain of states with one arc per label. The
// Label/StateId/Weight aliases are taken from the Arc template parameter.
// NOTE(review): the `template <class Arc>` header and access specifiers are
// not present in this excerpt - presumably dropped lines; confirm upstream.
92 class StringCompiler {
94 using Label = typename Arc::Label;
95 using StateId = typename Arc::StateId;
96 using Weight = typename Arc::Weight;
// token_type:     how input strings are split into labels.
// syms:           symbol table used for SYMBOL tokens; presumably tokens are
//                 parsed as integers when it is null (see the strtoll path
//                 of ConvertSymbolToLabel).
// unknown_label:  label substituted for tokens not found in `syms`; kNoLabel
//                 disables the substitution.
// allow_negative: whether negative labels are accepted by the SYMBOL path.
// NOTE(review): no `syms_(syms)` initializer appears in this excerpt. syms_,
// which operator() reads, would then be left uninitialized; the line appears
// to be missing from the listing - restore it from upstream.
98 explicit StringCompiler(StringTokenType token_type,
99 const SymbolTable *syms = nullptr,
100 Label unknown_label = kNoLabel,
101 bool allow_negative = false)
102 : token_type_(token_type),
104 unknown_label_(unknown_label),
105 allow_negative_(allow_negative) {}
107 // Compiles string into an FST.
// Converts `str` to labels using this compiler's token type, symbol table,
// unknown label and negativity setting. It then calls the Compile overload
// that matches FST (MutableFst, CompactStringFst or CompactWeightedStringFst)
// without an explicit weight; overloads that take a weight default it to
// Weight::One().
// Returns bool; presumably false when the conversion fails, true otherwise.
// NOTE(review): the template header (presumably `template <class FST>`), the
// tail of the ConvertStringToLabels call, the failure branch and the return
// are missing from this excerpt.
109 bool operator()(const string &str, FST *fst) const {
110 std::vector<Label> labels;
111 if (!internal::ConvertStringToLabels(str, token_type_, syms_,
112 unknown_label_, allow_negative_,
116 Compile(labels, fst);
// Compiles `str` into `fst` like the two-argument overload, but forwards
// `weight` to Compile as the final weight. For a MutableFst it becomes the
// weight of the final state; for a CompactWeightedStringFst it is paired with
// the last label.
// Among the Compile overloads visible here, none takes a CompactStringFst
// together with a weight, so this overload only instantiates for the other
// two targets.
// NOTE(review): the template header, the tail of the ConvertStringToLabels
// call, the failure branch and the return are missing from this excerpt.
121 bool operator()(const string &str, FST *fst, Weight weight) const {
122 std::vector<Label> labels;
123 if (!internal::ConvertStringToLabels(str, token_type_, syms_,
124 unknown_label_, allow_negative_,
128 Compile(labels, fst, std::move(weight));
// Writes `labels` into `fst` as a linear chain:
//  - states 0..labels.size();
//  - an arc i -> i+1 carrying labels[i] as both input and output label, with
//    weight One;
//  - `weight` as the final weight of state labels.size().
// NOTE(review): source lines are missing here. Presumably they clear `fst`
// beforehand, close the loop, and set the start state - confirm upstream.
// NOTE(review): `NumStates() <= labels.size()` and `i < labels.size()` compare
// StateId (presumably signed) against size_t; this is harmless for realistic
// sizes but triggers -Wsign-compare.
133 void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
134 Weight weight = Weight::One()) const {
// Make sure states 0..labels.size() exist.
136 while (fst->NumStates() <= labels.size()) fst->AddState();
137 for (StateId i = 0; i < labels.size(); ++i) {
138 fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
141 fst->SetFinal(labels.size(), std::move(weight));
// Fills a CompactStringFst directly from `labels` via SetCompactElements; this
// overload takes no weight.
// NOTE(review): the closing brace of this member is not in this excerpt.
144 template <class Unsigned>
145 void Compile(const std::vector<Label> &labels,
146 CompactStringFst<Arc, Unsigned> *fst) const {
147 fst->SetCompactElements(labels.begin(), labels.end());
// Fills a CompactWeightedStringFst. Every label except the last is paired with
// Weight::One(), and the last label is paired with `weight`. For an empty
// `labels` a single (kNoLabel, weight) element is stored instead, presumably
// encoding an FST that accepts only the empty string.
// NOTE(review): the closing braces of the loop and of this member are not in
// this excerpt.
150 template <class Unsigned>
151 void Compile(const std::vector<Label> &labels,
152 CompactWeightedStringFst<Arc, Unsigned> *fst,
153 const Weight &weight = Weight::One()) const {
154 std::vector<std::pair<Label, Weight>> compacts;
// Covers both cases: labels.size() elements when non-empty, one when empty.
155 compacts.reserve(labels.size() + 1);
// The cast happens before "- 1", so an empty `labels` gives a bound of -1 (no
// iterations) instead of an unsigned wrap-around - assuming StateId is signed.
156 for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
157 compacts.emplace_back(labels[i], Weight::One());
159 compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
160 fst->SetCompactElements(compacts.begin(), compacts.end());
// Compilation settings captured at construction; all const, so they are fixed
// for the object's lifetime.
163 const StringTokenType token_type_; // How input strings are tokenized.
164 const SymbolTable *syms_; // Symbol table (used when token type is symbol).
165 const Label unknown_label_; // Label for token missing from symbol table.
166 const bool allow_negative_; // Whether negative labels are accepted.
// Non-copyable and non-assignable. syms_ appears to be non-owning (nothing in
// this excerpt frees it), so presumably the caller keeps the table alive.
// NOTE(review): the class's closing `};` is not present in this excerpt.
168 StringCompiler(const StringCompiler &) = delete;
169 StringCompiler &operator=(const StringCompiler &) = delete;
172 // Functor for printing a string FST as a string.
// Follows the single path from the start state to a final state, collects the
// output labels along it, and renders them according to a StringTokenType.
// NOTE(review): the `template <class Arc>` header and access specifiers are
// not present in this excerpt; confirm against upstream.
174 class StringPrinter {
176 using Label = typename Arc::Label;
177 using StateId = typename Arc::StateId;
178 using Weight = typename Arc::Weight;
// token_type: how labels are rendered (SYMBOL, BYTE or UTF8).
// syms:       symbol table used by PrintLabel for SYMBOL output. It may be
//             null; presumably labels are then printed as integers, but the
//             null-handling code is not shown in this excerpt.
180 explicit StringPrinter(StringTokenType token_type,
181 const SymbolTable *syms = nullptr)
182 : token_type_(token_type), syms_(syms) {}
184 // Converts the FST into a string.
// Fails when FstToLabels rejects the FST (logged at VLOG(1); presumably
// returns false) or when PrintLabel fails for some label (returns false).
// Otherwise the text is produced as follows:
//  - SYMBOL: PrintLabel output separated by the last character of
//            FLAGS_fst_field_separator; *result is assigned.
//  - BYTE:   each label narrowed to a char and appended to *result (labels
//            outside char's range are truncated).
//  - UTF8:   delegated to LabelsToUTF8String.
//  - any other token type is logged at VLOG(1).
// Not const, because FstToLabels writes to the labels_ member.
// NOTE(review): lines are missing from this excerpt. Missing: the failure
// returns, branch braces, the tail of the unknown-token-type log, and the
// final return.
// NOTE(review): `*(FLAGS_fst_field_separator.rbegin())` dereferences rbegin()
// of a possibly empty string. That is undefined behavior when the flag is ""
// and there are at least two labels; guard it.
// NOTE(review): the BYTE branch appends whereas SYMBOL assigns, so a non-empty
// *result on entry survives unless a line not shown here clears it.
185 bool operator()(const Fst<Arc> &fst, string *result) {
186 if (!FstToLabels(fst)) {
187 VLOG(1) << "StringPrinter::operator(): FST is not a string";
191 if (token_type_ == StringTokenType::SYMBOL) {
192 std::stringstream sstrm;
193 for (size_t i = 0; i < labels_.size(); ++i) {
// Separator goes between symbols, never before the first one.
194 if (i) sstrm << *(FLAGS_fst_field_separator.rbegin());
195 if (!PrintLabel(labels_[i], sstrm)) return false;
197 *result = sstrm.str();
198 } else if (token_type_ == StringTokenType::BYTE) {
199 result->reserve(labels_.size());
200 for (size_t i = 0; i < labels_.size(); ++i) result->push_back(labels_[i]);
201 } else if (token_type_ == StringTokenType::UTF8) {
202 return LabelsToUTF8String(labels_, result);
204 VLOG(1) << "StringPrinter::operator(): Unknown token type: "
// Walks the FST from its start state, pushing each traversed arc's output
// label (olabel) onto labels_, until it reaches a state whose final weight is
// not Zero.
// The VLOG(2) messages cover these failure cases:
//  - no start state;
//  - a non-final state with no outgoing arc ("does not reach final state");
//  - an arc leading to kNoStateId;
//  - a state with more than one outgoing arc.
// Returns bool; presumably false in those cases and true otherwise.
// NOTE(review): lines are missing from this excerpt. Missing: clearing
// labels_, the aiter.Done()/Next() checks, advancing `s` to the arc's next
// state, and the returns and braces.
// NOTE(review): nothing visible stops a cycle of non-final states with one
// arc each (that would loop forever), or rejects a final state that still has
// outgoing arcs (traversal simply stops there). Confirm whether the missing
// lines handle either case.
212 bool FstToLabels(const Fst<Arc> &fst) {
214 auto s = fst.Start();
215 if (s == kNoStateId) {
216 VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
220 while (fst.Final(s) == Weight::Zero()) {
221 ArcIterator<Fst<Arc>> aiter(fst, s);
223 VLOG(2) << "StringPrinter::FstToLabels: String FST traversal does "
224 << "not reach final state";
227 const auto &arc = aiter.Value();
228 labels_.push_back(arc.olabel);
230 if (s == kNoStateId) {
231 VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid state";
236 VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
237 << "outgoing arcs found";
// Writes the textual form of `label` to `ostrm`. What is visible: the label is
// looked up with syms_->Find(label), and a label with no textual symbol is
// logged at VLOG(2) together with the symbol table.
// Returns bool; presumably false for an unmapped label.
// NOTE(review): lines are missing from this excerpt. Missing: the null check
// on syms_ implied by the constructor's nullptr default, the "not found"
// test, the write to `ostrm`, any integer fallback, and the returns.
244 bool PrintLabel(Label label, std::ostream &ostrm) {
246 const auto symbol = syms_->Find(label);
248 VLOG(2) << "StringPrinter::PrintLabel: Integer " << label << " is not "
249 << "mapped to any textual symbol, symbol table = "
// Printing settings captured at construction.
260 const StringTokenType token_type_;
261 const SymbolTable *syms_; // Symbol table (used when token type is symbol).
// Scratch buffer filled by FstToLabels on each call.
262 std::vector<Label> labels_; // Output labels (arc.olabel) of the FST being printed.
// Non-copyable and non-assignable.
// NOTE(review): the class's closing `};` is not present in this excerpt.
264 StringPrinter(const StringPrinter &) = delete;
265 StringPrinter &operator=(const StringPrinter &) = delete;
270 #endif // FST_STRING_H_