66e21861c0d3ac188e1bec79ad95a436d5518c69
[platform/upstream/openfst.git] / src / include / fst / script / draw-impl.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to draw a binary FST by producing a text file in dot format, a helper
5 // class to fstdraw.cc.
6
7 #ifndef FST_SCRIPT_DRAW_IMPL_H_
8 #define FST_SCRIPT_DRAW_IMPL_H_
9
10 #include <ostream>
11 #include <sstream>
12 #include <string>
13
14 #include <fst/fst.h>
15 #include <fst/util.h>
16 #include <fst/script/fst-class.h>
17
18 namespace fst {
19
20 // Print a binary FST in GraphViz textual format (helper class for fstdraw.cc).
21 // WARNING: Stand-alone use not recommend.
22 template <class Arc>
23 class FstDrawer {
24  public:
25   using Label = typename Arc::Label;
26   using StateId = typename Arc::StateId;
27   using Weight = typename Arc::Weight;
28
29   FstDrawer(const Fst<Arc> &fst, const SymbolTable *isyms,
30             const SymbolTable *osyms, const SymbolTable *ssyms, bool accep,
31             const string &title, float width, float height, bool portrait,
32             bool vertical, float ranksep, float nodesep, int fontsize,
33             int precision, const string &float_format, bool show_weight_one)
34       : fst_(fst),
35         isyms_(isyms),
36         osyms_(osyms),
37         ssyms_(ssyms),
38         accep_(accep && fst.Properties(kAcceptor, true)),
39         ostrm_(nullptr),
40         title_(title),
41         width_(width),
42         height_(height),
43         portrait_(portrait),
44         vertical_(vertical),
45         ranksep_(ranksep),
46         nodesep_(nodesep),
47         fontsize_(fontsize),
48         precision_(precision),
49         float_format_(float_format),
50         show_weight_one_(show_weight_one) {}
51
52   // Draw Fst to an output buffer (or stdout if buf = 0)
53   void Draw(std::ostream *strm, const string &dest) {
54     ostrm_ = strm;
55     ostrm_->precision(precision_);
56     if (float_format_ == "e")
57         ostrm_->setf(std::ios_base::scientific, std::ios_base::floatfield);
58     if (float_format_ == "f")
59         ostrm_->setf(std::ios_base::fixed, std::ios_base::floatfield);
60     // O.w. defaults to "g" per standard lib.
61     dest_ = dest;
62     StateId start = fst_.Start();
63     if (start == kNoStateId) return;
64     PrintString("digraph FST {\n");
65     if (vertical_) {
66       PrintString("rankdir = BT;\n");
67     } else {
68       PrintString("rankdir = LR;\n");
69     }
70     PrintString("size = \"");
71     Print(width_);
72     PrintString(",");
73     Print(height_);
74     PrintString("\";\n");
75     if (!dest_.empty()) PrintString("label = \"" + title_ + "\";\n");
76     PrintString("center = 1;\n");
77     if (portrait_) {
78       PrintString("orientation = Portrait;\n");
79     } else {
80       PrintString("orientation = Landscape;\n");
81     }
82     PrintString("ranksep = \"");
83     Print(ranksep_);
84     PrintString("\";\n");
85     PrintString("nodesep = \"");
86     Print(nodesep_);
87     PrintString("\";\n");
88     // Initial state first.
89     DrawState(start);
90     for (StateIterator<Fst<Arc>> siter(fst_); !siter.Done(); siter.Next()) {
91       const auto s = siter.Value();
92       if (s != start) DrawState(s);
93     }
94     PrintString("}\n");
95   }
96
97  private:
98   // Maximum line length in text file.
99   static const int kLineLen = 8096;
100
101   void PrintString(const string &str) const { *ostrm_ << str; }
102
103   // Escapes backslash and double quote if these occur in the string. Dot will
104   // not deal gracefully with these if they are not escaped.
105   inline void EscapeChars(const string &str, string *ns) const {
106     const char *c = str.c_str();
107     while (*c) {
108       if (*c == '\\' || *c == '"') ns->push_back('\\');
109       ns->push_back(*c);
110       ++c;
111     }
112   }
113
114   void PrintId(StateId id, const SymbolTable *syms, const char *name) const {
115     if (syms) {
116       auto symbol = syms->Find(id);
117       if (symbol == "") {
118         FSTERROR() << "FstDrawer: Integer " << id
119                    << " is not mapped to any textual symbol"
120                    << ", symbol table = " << syms->Name()
121                    << ", destination = " << dest_;
122         symbol = "?";
123       }
124       string nsymbol;
125       EscapeChars(symbol, &nsymbol);
126       PrintString(nsymbol);
127     } else {
128       PrintString(std::to_string(id));
129     }
130   }
131
132   void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); }
133
134   void PrintILabel(Label label) const {
135     PrintId(label, isyms_, "arc input label");
136   }
137
138   void PrintOLabel(Label label) const {
139     PrintId(label, osyms_, "arc output label");
140   }
141
142   template <class T>
143   void Print(T t) const { *ostrm_ << t; }
144
145   void DrawState(StateId s) const {
146     Print(s);
147     PrintString(" [label = \"");
148     PrintStateId(s);
149     const auto weight = fst_.Final(s);
150     if (weight != Weight::Zero()) {
151       if (show_weight_one_ || (weight != Weight::One())) {
152         PrintString("/");
153         Print(weight);
154       }
155       PrintString("\", shape = doublecircle,");
156     } else {
157       PrintString("\", shape = circle,");
158     }
159     if (s == fst_.Start()) {
160       PrintString(" style = bold,");
161     } else {
162       PrintString(" style = solid,");
163     }
164     PrintString(" fontsize = ");
165     Print(fontsize_);
166     PrintString("]\n");
167     for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
168       const auto &arc = aiter.Value();
169       PrintString("\t");
170       Print(s);
171       PrintString(" -> ");
172       Print(arc.nextstate);
173       PrintString(" [label = \"");
174       PrintILabel(arc.ilabel);
175       if (!accep_) {
176         PrintString(":");
177         PrintOLabel(arc.olabel);
178       }
179       if (show_weight_one_ || (arc.weight != Weight::One())) {
180         PrintString("/");
181         Print(arc.weight);
182       }
183       PrintString("\", fontsize = ");
184       Print(fontsize_);
185       PrintString("];\n");
186     }
187   }
188
189   const Fst<Arc> &fst_;
190   const SymbolTable *isyms_;  // ilabel symbol table.
191   const SymbolTable *osyms_;  // olabel symbol table.
192   const SymbolTable *ssyms_;  // slabel symbol table.
193   bool accep_;                // Print as acceptor when possible.
194   std::ostream *ostrm_;       // Drawn FST destination.
195   string dest_;               // Drawn FST destination name.
196
197   string title_;
198   float width_;
199   float height_;
200   bool portrait_;
201   bool vertical_;
202   float ranksep_;
203   float nodesep_;
204   int fontsize_;
205   int precision_;
206   string float_format_;
207   bool show_weight_one_;
208
209   FstDrawer(const FstDrawer &) = delete;
210   FstDrawer &operator=(const FstDrawer &) = delete;
211 };
212
213 }  // namespace fst
214
215 #endif  // FST_SCRIPT_DRAW_IMPL_H_