Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / script / draw-impl.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to draw a binary FST by producing a text file in dot format, a helper
5 // class to fstdraw.cc.
6
7 #ifndef FST_SCRIPT_DRAW_IMPL_H_
8 #define FST_SCRIPT_DRAW_IMPL_H_
9
10 #include <ostream>
11 #include <sstream>
12 #include <string>
13
14 #include <fst/fst.h>
15 #include <fst/util.h>
16 #include <fst/script/fst-class.h>
17
18 namespace fst {
19
20 // Print a binary FST in GraphViz textual format (helper class for fstdraw.cc).
21 // WARNING: Stand-alone use not recommend.
22 template <class Arc>
23 class FstDrawer {
24  public:
25   using Label = typename Arc::Label;
26   using StateId = typename Arc::StateId;
27   using Weight = typename Arc::Weight;
28
29   FstDrawer(const Fst<Arc> &fst, const SymbolTable *isyms,
30             const SymbolTable *osyms, const SymbolTable *ssyms, bool accep,
31             const string &title, float width, float height, bool portrait,
32             bool vertical, float ranksep, float nodesep, int fontsize,
33             int precision, const string &float_format, bool show_weight_one)
34       : fst_(fst),
35         isyms_(isyms),
36         osyms_(osyms),
37         ssyms_(ssyms),
38         accep_(accep && fst.Properties(kAcceptor, true)),
39         ostrm_(nullptr),
40         title_(title),
41         width_(width),
42         height_(height),
43         portrait_(portrait),
44         vertical_(vertical),
45         ranksep_(ranksep),
46         nodesep_(nodesep),
47         fontsize_(fontsize),
48         precision_(precision),
49         float_format_(float_format),
50         show_weight_one_(show_weight_one) {}
51
52   // Draw Fst to an output buffer (or stdout if buf = 0)
53   void Draw(std::ostream *strm, const string &dest) {
54     ostrm_ = strm;
55     SetStreamState(ostrm_);
56     dest_ = dest;
57     StateId start = fst_.Start();
58     if (start == kNoStateId) return;
59     PrintString("digraph FST {\n");
60     if (vertical_) {
61       PrintString("rankdir = BT;\n");
62     } else {
63       PrintString("rankdir = LR;\n");
64     }
65     PrintString("size = \"");
66     Print(width_);
67     PrintString(",");
68     Print(height_);
69     PrintString("\";\n");
70     if (!dest_.empty()) PrintString("label = \"" + title_ + "\";\n");
71     PrintString("center = 1;\n");
72     if (portrait_) {
73       PrintString("orientation = Portrait;\n");
74     } else {
75       PrintString("orientation = Landscape;\n");
76     }
77     PrintString("ranksep = \"");
78     Print(ranksep_);
79     PrintString("\";\n");
80     PrintString("nodesep = \"");
81     Print(nodesep_);
82     PrintString("\";\n");
83     // Initial state first.
84     DrawState(start);
85     for (StateIterator<Fst<Arc>> siter(fst_); !siter.Done(); siter.Next()) {
86       const auto s = siter.Value();
87       if (s != start) DrawState(s);
88     }
89     PrintString("}\n");
90   }
91
92  private:
93   void SetStreamState(std::ostream* strm) const {
94     strm->precision(precision_);
95     if (float_format_ == "e")
96         strm->setf(std::ios_base::scientific, std::ios_base::floatfield);
97     if (float_format_ == "f")
98         strm->setf(std::ios_base::fixed, std::ios_base::floatfield);
99     // O.w. defaults to "g" per standard lib.
100   }
101
102   void PrintString(const string &str) const { *ostrm_ << str; }
103
104   // Escapes backslash and double quote if these occur in the string. Dot will
105   // not deal gracefully with these if they are not escaped.
106   static string Escape(const string &str) {
107     string ns;
108     for (char c : str) {
109       if (c == '\\' || c == '"') ns.push_back('\\');
110       ns.push_back(c);
111     }
112     return ns;
113   }
114
115   void PrintId(StateId id, const SymbolTable *syms, const char *name) const {
116     if (syms) {
117       auto symbol = syms->Find(id);
118       if (symbol.empty()) {
119         FSTERROR() << "FstDrawer: Integer " << id
120                    << " is not mapped to any textual symbol"
121                    << ", symbol table = " << syms->Name()
122                    << ", destination = " << dest_;
123         symbol = "?";
124       }
125       PrintString(Escape(symbol));
126     } else {
127       PrintString(std::to_string(id));
128     }
129   }
130
131   void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); }
132
133   void PrintILabel(Label label) const {
134     PrintId(label, isyms_, "arc input label");
135   }
136
137   void PrintOLabel(Label label) const {
138     PrintId(label, osyms_, "arc output label");
139   }
140
141   void PrintWeight(Weight w) const {
142     // Weight may have double quote characters in it, so escape it.
143     PrintString(Escape(ToString(w)));
144   }
145
146   template <class T>
147   void Print(T t) const { *ostrm_ << t; }
148
149   template <class T>
150   string ToString(T t) const {
151     std::stringstream ss;
152     SetStreamState(&ss);
153     ss << t;
154     return ss.str();
155   }
156
157   void DrawState(StateId s) const {
158     Print(s);
159     PrintString(" [label = \"");
160     PrintStateId(s);
161     const auto weight = fst_.Final(s);
162     if (weight != Weight::Zero()) {
163       if (show_weight_one_ || (weight != Weight::One())) {
164         PrintString("/");
165         PrintWeight(weight);
166       }
167       PrintString("\", shape = doublecircle,");
168     } else {
169       PrintString("\", shape = circle,");
170     }
171     if (s == fst_.Start()) {
172       PrintString(" style = bold,");
173     } else {
174       PrintString(" style = solid,");
175     }
176     PrintString(" fontsize = ");
177     Print(fontsize_);
178     PrintString("]\n");
179     for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
180       const auto &arc = aiter.Value();
181       PrintString("\t");
182       Print(s);
183       PrintString(" -> ");
184       Print(arc.nextstate);
185       PrintString(" [label = \"");
186       PrintILabel(arc.ilabel);
187       if (!accep_) {
188         PrintString(":");
189         PrintOLabel(arc.olabel);
190       }
191       if (show_weight_one_ || (arc.weight != Weight::One())) {
192         PrintString("/");
193         PrintWeight(arc.weight);
194       }
195       PrintString("\", fontsize = ");
196       Print(fontsize_);
197       PrintString("];\n");
198     }
199   }
200
201   const Fst<Arc> &fst_;
202   const SymbolTable *isyms_;  // ilabel symbol table.
203   const SymbolTable *osyms_;  // olabel symbol table.
204   const SymbolTable *ssyms_;  // slabel symbol table.
205   bool accep_;                // Print as acceptor when possible.
206   std::ostream *ostrm_;       // Drawn FST destination.
207   string dest_;               // Drawn FST destination name.
208
209   string title_;
210   float width_;
211   float height_;
212   bool portrait_;
213   bool vertical_;
214   float ranksep_;
215   float nodesep_;
216   int fontsize_;
217   int precision_;
218   string float_format_;
219   bool show_weight_one_;
220
221   FstDrawer(const FstDrawer &) = delete;
222   FstDrawer &operator=(const FstDrawer &) = delete;
223 };
224
225 }  // namespace fst
226
227 #endif  // FST_SCRIPT_DRAW_IMPL_H_