Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / script / compile-impl.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to to compile a binary FST from textual input.
5
6 #ifndef FST_SCRIPT_COMPILE_IMPL_H_
7 #define FST_SCRIPT_COMPILE_IMPL_H_
8
9 #include <iostream>
10 #include <memory>
11 #include <sstream>
12 #include <string>
13 #include <vector>
14
15 #include <fst/fst.h>
16 #include <fst/util.h>
17 #include <fst/vector-fst.h>
18 #include <unordered_map>
19
20 DECLARE_string(fst_field_separator);
21
22 namespace fst {
23
24 // Compile a binary Fst from textual input, helper class for fstcompile.cc
25 // WARNING: Stand-alone use of this class not recommended, most code should
26 // read/write using the binary format which is much more efficient.
27 template <class Arc>
28 class FstCompiler {
29  public:
30   using Label = typename Arc::Label;
31   using StateId = typename Arc::StateId;
32   using Weight = typename Arc::Weight;
33
34   // WARNING: use of negative labels not recommended as it may cause conflicts.
35   // If add_symbols_ is true, then the symbols will be dynamically added to the
36   // symbol tables. This is only useful if you set the (i/o)keep flag to attach
37   // the final symbol table, or use the accessors. (The input symbol tables are
38   // const and therefore not changed.)
39   FstCompiler(std::istream &istrm, const string &source,  // NOLINT
40               const SymbolTable *isyms, const SymbolTable *osyms,
41               const SymbolTable *ssyms, bool accep, bool ikeep,
42               bool okeep, bool nkeep, bool allow_negative_labels = false) {
43     std::unique_ptr<SymbolTable> misyms(isyms ? isyms->Copy() : nullptr);
44     std::unique_ptr<SymbolTable> mosyms(osyms ? osyms->Copy() : nullptr);
45     std::unique_ptr<SymbolTable> mssyms(ssyms ? ssyms->Copy() : nullptr);
46     Init(istrm, source, misyms.get(), mosyms.get(), mssyms.get(), accep,
47          ikeep, okeep, nkeep, allow_negative_labels, false);
48   }
49
50   FstCompiler(std::istream &istrm, const string &source,  // NOLINT
51               SymbolTable *isyms, SymbolTable *osyms, SymbolTable *ssyms,
52               bool accep, bool ikeep, bool okeep, bool nkeep,
53               bool allow_negative_labels, bool add_symbols) {
54     Init(istrm, source, isyms, osyms, ssyms, accep, ikeep, okeep, nkeep,
55          allow_negative_labels, add_symbols);
56   }
57
58   void Init(std::istream &istrm, const string &source,  // NOLINT
59             SymbolTable *isyms, SymbolTable *osyms, SymbolTable *ssyms,
60             bool accep, bool ikeep, bool okeep, bool nkeep,
61             bool allow_negative_labels, bool add_symbols) {
62     nline_ = 0;
63     source_ = source;
64     isyms_ = isyms;
65     osyms_ = osyms;
66     ssyms_ = ssyms;
67     nstates_ = 0;
68     keep_state_numbering_ = nkeep;
69     allow_negative_labels_ = allow_negative_labels;
70     add_symbols_ = add_symbols;
71     bool start_state_populated = false;
72     char line[kLineLen];
73     const string separator = FLAGS_fst_field_separator + "\n";
74     while (istrm.getline(line, kLineLen)) {
75       ++nline_;
76       std::vector<char *> col;
77       SplitToVector(line, separator.c_str(), &col, true);
78       if (col.empty() || col[0][0] == '\0')
79         continue;
80       if (col.size() > 5 || (col.size() > 4 && accep) ||
81           (col.size() == 3 && !accep)) {
82         FSTERROR() << "FstCompiler: Bad number of columns, source = " << source_
83                    << ", line = " << nline_;
84         fst_.SetProperties(kError, kError);
85         return;
86       }
87       StateId s = StrToStateId(col[0]);
88       while (s >= fst_.NumStates()) fst_.AddState();
89       if (!start_state_populated) {
90         fst_.SetStart(s);
91         start_state_populated = true;
92       }
93
94       Arc arc;
95       StateId d = s;
96       switch (col.size()) {
97         case 1:
98           fst_.SetFinal(s, Weight::One());
99           break;
100         case 2:
101           fst_.SetFinal(s, StrToWeight(col[1], true));
102           break;
103         case 3:
104           arc.nextstate = d = StrToStateId(col[1]);
105           arc.ilabel = StrToILabel(col[2]);
106           arc.olabel = arc.ilabel;
107           arc.weight = Weight::One();
108           fst_.AddArc(s, arc);
109           break;
110         case 4:
111           arc.nextstate = d = StrToStateId(col[1]);
112           arc.ilabel = StrToILabel(col[2]);
113           if (accep) {
114             arc.olabel = arc.ilabel;
115             arc.weight = StrToWeight(col[3], true);
116           } else {
117             arc.olabel = StrToOLabel(col[3]);
118             arc.weight = Weight::One();
119           }
120           fst_.AddArc(s, arc);
121           break;
122         case 5:
123           arc.nextstate = d = StrToStateId(col[1]);
124           arc.ilabel = StrToILabel(col[2]);
125           arc.olabel = StrToOLabel(col[3]);
126           arc.weight = StrToWeight(col[4], true);
127           fst_.AddArc(s, arc);
128       }
129       while (d >= fst_.NumStates()) fst_.AddState();
130     }
131     if (ikeep) fst_.SetInputSymbols(isyms);
132     if (okeep) fst_.SetOutputSymbols(osyms);
133   }
134
135   const VectorFst<Arc> &Fst() const { return fst_; }
136
137  private:
138   // Maximum line length in text file.
139   static constexpr int kLineLen = 8096;
140
141   StateId StrToId(const char *s, SymbolTable *syms, const char *name,
142                   bool allow_negative = false) const {
143     StateId n = 0;
144     if (syms) {
145       n = (add_symbols_) ? syms->AddSymbol(s) : syms->Find(s);
146       if (n == -1 || (!allow_negative && n < 0)) {
147         FSTERROR() << "FstCompiler: Symbol \"" << s
148                    << "\" is not mapped to any integer " << name
149                    << ", symbol table = " << syms->Name()
150                    << ", source = " << source_ << ", line = " << nline_;
151         fst_.SetProperties(kError, kError);
152       }
153     } else {
154       char *p;
155       n = strtoll(s, &p, 10);
156       if (p < s + strlen(s) || (!allow_negative && n < 0)) {
157         FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
158                    << "\", source = " << source_ << ", line = " << nline_;
159         fst_.SetProperties(kError, kError);
160       }
161     }
162     return n;
163   }
164
165   StateId StrToStateId(const char *s) {
166     StateId n = StrToId(s, ssyms_, "state ID");
167     if (keep_state_numbering_) return n;
168     // Remaps state IDs to make dense set.
169     const auto it = states_.find(n);
170     if (it == states_.end()) {
171       states_[n] = nstates_;
172       return nstates_++;
173     } else {
174       return it->second;
175     }
176   }
177
178   StateId StrToILabel(const char *s) const {
179     return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_);
180   }
181
182   StateId StrToOLabel(const char *s) const {
183     return StrToId(s, osyms_, "arc olabel", allow_negative_labels_);
184   }
185
186   Weight StrToWeight(const char *s, bool allow_zero) const {
187     Weight w;
188     std::istringstream strm(s);
189     strm >> w;
190     if (!strm || (!allow_zero && w == Weight::Zero())) {
191       FSTERROR() << "FstCompiler: Bad weight = \"" << s
192                  << "\", source = " << source_ << ", line = " << nline_;
193       fst_.SetProperties(kError, kError);
194       w = Weight::NoWeight();
195     }
196     return w;
197   }
198
199   mutable VectorFst<Arc> fst_;
200   size_t nline_;
201   string source_;       // Text FST source name.
202   SymbolTable *isyms_;  // ilabel symbol table (not owned).
203   SymbolTable *osyms_;  // olabel symbol table (not owned).
204   SymbolTable *ssyms_;  // slabel symbol table (not owned).
205   std::unordered_map<StateId, StateId> states_;  // State ID map.
206   StateId nstates_;                              // Number of seen states.
207   bool keep_state_numbering_;
208   bool allow_negative_labels_;  // Not recommended; may cause conflicts.
209   bool add_symbols_;            // Add to symbol tables on-the fly.
210
211   FstCompiler(const FstCompiler &) = delete;
212   FstCompiler &operator=(const FstCompiler &) = delete;
213 };
214
215 }  // namespace fst
216
217 #endif  // FST_SCRIPT_COMPILE_IMPL_H_