Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / script / fst-class.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3
4 #ifndef FST_SCRIPT_FST_CLASS_H_
5 #define FST_SCRIPT_FST_CLASS_H_
6
7 #include <algorithm>
8 #include <limits>
9 #include <string>
10 #include <type_traits>
11
12 #include <fst/expanded-fst.h>
13 #include <fst/fst.h>
14 #include <fst/mutable-fst.h>
15 #include <fst/vector-fst.h>
16 #include <fst/script/arc-class.h>
17 #include <fst/script/weight-class.h>
18
19 // Classes to support "boxing" all existing types of FST arcs in a single
20 // FstClass which hides the arc types. This allows clients to load
21 // and work with FSTs without knowing the arc type. These classes are only
22 // recommended for use in high-level scripting applications. Most users should
23 // use the lower-level templated versions corresponding to these classes.
24
25 namespace fst {
26 namespace script {
27
28 // Abstract base class defining the set of functionalities implemented in all
29 // impls and passed through by all bases. Below FstClassBase the class
30 // hierarchy bifurcates; FstClassImplBase serves as the base class for all
31 // implementations (of which FstClassImpl is currently the only one) and
32 // FstClass serves as the base class for all interfaces.
33
34 class FstClassBase {
35  public:
36   virtual const string &ArcType() const = 0;
37   virtual WeightClass Final(int64) const = 0;
38   virtual const string &FstType() const = 0;
39   virtual const SymbolTable *InputSymbols() const = 0;
40   virtual size_t NumArcs(int64) const = 0;
41   virtual size_t NumInputEpsilons(int64) const = 0;
42   virtual size_t NumOutputEpsilons(int64) const = 0;
43   virtual const SymbolTable *OutputSymbols() const = 0;
44   virtual uint64 Properties(uint64, bool) const = 0;
45   virtual int64 Start() const = 0;
46   virtual const string &WeightType() const = 0;
47   virtual bool ValidStateId(int64) const = 0;
48   virtual bool Write(const string &) const = 0;
49   virtual bool Write(std::ostream &, const string &) const = 0;
50   virtual ~FstClassBase() {}
51 };
52
53 // Adds all the MutableFst methods.
54 class FstClassImplBase : public FstClassBase {
55  public:
56   virtual bool AddArc(int64, const ArcClass &) = 0;
57   virtual int64 AddState() = 0;
58   virtual FstClassImplBase *Copy() = 0;
59   virtual bool DeleteArcs(int64, size_t) = 0;
60   virtual bool DeleteArcs(int64) = 0;
61   virtual bool DeleteStates(const std::vector<int64> &) = 0;
62   virtual void DeleteStates() = 0;
63   virtual SymbolTable *MutableInputSymbols() = 0;
64   virtual SymbolTable *MutableOutputSymbols() = 0;
65   virtual int64 NumStates() const = 0;
66   virtual bool ReserveArcs(int64, size_t) = 0;
67   virtual void ReserveStates(int64) = 0;
68   virtual void SetInputSymbols(SymbolTable *) = 0;
69   virtual bool SetFinal(int64, const WeightClass &) = 0;
70   virtual void SetOutputSymbols(SymbolTable *) = 0;
71   virtual void SetProperties(uint64, uint64) = 0;
72   virtual bool SetStart(int64) = 0;
73   ~FstClassImplBase() override {}
74 };
75
76 // Containiner class wrapping an Fst<Arc>, hiding its arc type. Whether this
77 // Fst<Arc> pointer refers to a special kind of FST (e.g. a MutableFst) is
78 // known by the type of interface class that owns the pointer to this
79 // container.
80
81 template <class Arc>
82 class FstClassImpl : public FstClassImplBase {
83  public:
84   explicit FstClassImpl(Fst<Arc> *impl, bool should_own = false)
85       : impl_(should_own ? impl : impl->Copy()) {}
86
87   explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) {}
88
89   // Warning: calling this method casts the FST to a mutable FST.
90   bool AddArc(int64 s, const ArcClass &ac) final {
91     if (!ValidStateId(s)) return false;
92     // Note that we do not check that the destination state is valid, so users
93     // can add arcs before they add the corresponding states. Verify can be
94     // used to determine whether any arc has a nonexisting destination.
95     Arc arc(ac.ilabel, ac.olabel, *ac.weight.GetWeight<typename Arc::Weight>(),
96             ac.nextstate);
97     static_cast<MutableFst<Arc> *>(impl_.get())->AddArc(s, arc);
98     return true;
99   }
100
101   // Warning: calling this method casts the FST to a mutable FST.
102   int64 AddState() final {
103     return static_cast<MutableFst<Arc> *>(impl_.get())->AddState();
104   }
105
106   const string &ArcType() const final { return Arc::Type(); }
107
108   FstClassImpl *Copy() final { return new FstClassImpl<Arc>(impl_.get()); }
109
110   // Warning: calling this method casts the FST to a mutable FST.
111   bool DeleteArcs(int64 s, size_t n) final {
112     if (!ValidStateId(s)) return false;
113     static_cast<MutableFst<Arc> *>(impl_.get())->DeleteArcs(s, n);
114     return true;
115   }
116
117   // Warning: calling this method casts the FST to a mutable FST.
118   bool DeleteArcs(int64 s) final {
119     if (!ValidStateId(s)) return false;
120     static_cast<MutableFst<Arc> *>(impl_.get())->DeleteArcs(s);
121     return true;
122   }
123
124   // Warning: calling this method casts the FST to a mutable FST.
125   bool DeleteStates(const std::vector<int64> &dstates) final {
126     for (const auto &state : dstates)
127       if (!ValidStateId(state)) return false;
128     // Warning: calling this method with any integers beyond the precision of
129     // the underlying FST will result in truncation.
130     std::vector<typename Arc::StateId> typed_dstates(dstates.size());
131     std::copy(dstates.begin(), dstates.end(), typed_dstates.begin());
132     static_cast<MutableFst<Arc> *>(impl_.get())->DeleteStates(typed_dstates);
133     return true;
134   }
135
136   // Warning: calling this method casts the FST to a mutable FST.
137   void DeleteStates() final {
138     static_cast<MutableFst<Arc> *>(impl_.get())->DeleteStates();
139   }
140
141   WeightClass Final(int64 s) const final {
142     if (!ValidStateId(s)) return WeightClass::NoWeight(WeightType());
143     WeightClass w(impl_->Final(s));
144     return w;
145   }
146
147   const string &FstType() const final { return impl_->Type(); }
148
149   const SymbolTable *InputSymbols() const final {
150     return impl_->InputSymbols();
151   }
152
153   // Warning: calling this method casts the FST to a mutable FST.
154   SymbolTable *MutableInputSymbols() final {
155     return static_cast<MutableFst<Arc> *>(impl_.get())->MutableInputSymbols();
156   }
157
158   // Warning: calling this method casts the FST to a mutable FST.
159   SymbolTable *MutableOutputSymbols() final {
160     return static_cast<MutableFst<Arc> *>(impl_.get())->MutableOutputSymbols();
161   }
162
163   // Signals failure by returning size_t max.
164   size_t NumArcs(int64 s) const final {
165     return ValidStateId(s) ? impl_->NumArcs(s)
166                            : std::numeric_limits<size_t>::max();
167   }
168
169   // Signals failure by returning size_t max.
170   size_t NumInputEpsilons(int64 s) const final {
171     return ValidStateId(s) ? impl_->NumInputEpsilons(s)
172                            : std::numeric_limits<size_t>::max();
173   }
174
175   // Signals failure by returning size_t max.
176   size_t NumOutputEpsilons(int64 s) const final {
177     return ValidStateId(s) ? impl_->NumOutputEpsilons(s)
178                            : std::numeric_limits<size_t>::max();
179   }
180
181   // Warning: calling this method casts the FST to a mutable FST.
182   int64 NumStates() const final {
183     return static_cast<MutableFst<Arc> *>(impl_.get())->NumStates();
184   }
185
186   uint64 Properties(uint64 mask, bool test) const final {
187     return impl_->Properties(mask, test);
188   }
189
190   // Warning: calling this method casts the FST to a mutable FST.
191   bool ReserveArcs(int64 s, size_t n) final {
192     if (!ValidStateId(s)) return false;
193     static_cast<MutableFst<Arc> *>(impl_.get())->ReserveArcs(s, n);
194     return true;
195   }
196
197   // Warning: calling this method casts the FST to a mutable FST.
198   void ReserveStates(int64 s) final {
199     static_cast<MutableFst<Arc> *>(impl_.get())->ReserveStates(s);
200   }
201
202   const SymbolTable *OutputSymbols() const final {
203     return impl_->OutputSymbols();
204   }
205
206   // Warning: calling this method casts the FST to a mutable FST.
207   void SetInputSymbols(SymbolTable *isyms) final {
208     static_cast<MutableFst<Arc> *>(impl_.get())->SetInputSymbols(isyms);
209   }
210
211   // Warning: calling this method casts the FST to a mutable FST.
212   bool SetFinal(int64 s, const WeightClass &weight) final {
213     if (!ValidStateId(s)) return false;
214     static_cast<MutableFst<Arc> *>(impl_.get())
215         ->SetFinal(s, *weight.GetWeight<typename Arc::Weight>());
216     return true;
217   }
218
219   // Warning: calling this method casts the FST to a mutable FST.
220   void SetOutputSymbols(SymbolTable *osyms) final {
221     static_cast<MutableFst<Arc> *>(impl_.get())->SetOutputSymbols(osyms);
222   }
223
224   // Warning: calling this method casts the FST to a mutable FST.
225   void SetProperties(uint64 props, uint64 mask) final {
226     static_cast<MutableFst<Arc> *>(impl_.get())->SetProperties(props, mask);
227   }
228
229   // Warning: calling this method casts the FST to a mutable FST.
230   bool SetStart(int64 s) final {
231     if (!ValidStateId(s)) return false;
232     static_cast<MutableFst<Arc> *>(impl_.get())->SetStart(s);
233     return true;
234   }
235
236   int64 Start() const final { return impl_->Start(); }
237
238   bool ValidStateId(int64 s) const final {
239     // This cowardly refuses to count states if the FST is not yet expanded.
240     if (!Properties(kExpanded, true)) {
241       FSTERROR() << "Cannot get number of states for unexpanded FST";
242       return false;
243     }
244     // If the FST is already expanded, CountStates calls NumStates.
245     if (s < 0 || s >= CountStates(*impl_)) {
246       FSTERROR() << "State ID " << s << " not valid";
247       return false;
248     }
249     return true;
250   }
251
252   const string &WeightType() const final { return Arc::Weight::Type(); }
253
254   bool Write(const string &fname) const final { return impl_->Write(fname); }
255
256   bool Write(std::ostream &ostr, const string &fname) const final {
257     const FstWriteOptions opts(fname);
258     return impl_->Write(ostr, opts);
259   }
260
261   ~FstClassImpl() override {}
262
263   Fst<Arc> *GetImpl() const { return impl_.get(); }
264
265  private:
266   std::unique_ptr<Fst<Arc>> impl_;
267 };
268
269 // BASE CLASS DEFINITIONS
270
271 class MutableFstClass;
272
273 class FstClass : public FstClassBase {
274  public:
275   FstClass() : impl_(nullptr) {}
276
277   template <class Arc>
278   explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) {}
279
280   FstClass(const FstClass &other)
281       : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {}
282
283   FstClass &operator=(const FstClass &other) {
284     impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy());
285     return *this;
286   }
287
288   WeightClass Final(int64 s) const final { return impl_->Final(s); }
289
290   const string &ArcType() const final { return impl_->ArcType(); }
291
292   const string &FstType() const final { return impl_->FstType(); }
293
294   const SymbolTable *InputSymbols() const final {
295     return impl_->InputSymbols();
296   }
297
298   size_t NumArcs(int64 s) const final { return impl_->NumArcs(s); }
299
300   size_t NumInputEpsilons(int64 s) const final {
301     return impl_->NumInputEpsilons(s);
302   }
303
304   size_t NumOutputEpsilons(int64 s) const final {
305     return impl_->NumOutputEpsilons(s);
306   }
307
308   const SymbolTable *OutputSymbols() const final {
309     return impl_->OutputSymbols();
310   }
311
312   uint64 Properties(uint64 mask, bool test) const final {
313     // Special handling for FSTs with a null impl.
314     if (!impl_) return kError & mask;
315     return impl_->Properties(mask, test);
316   }
317
318   static FstClass *Read(const string &fname);
319
320   static FstClass *Read(std::istream &istrm, const string &source);
321
322   int64 Start() const final { return impl_->Start(); }
323
324   bool ValidStateId(int64 s) const final { return impl_->ValidStateId(s); }
325
326   const string &WeightType() const final { return impl_->WeightType(); }
327
328   // Helper that logs an ERROR if the weight type of an FST and a WeightClass
329   // don't match.
330
331   bool WeightTypesMatch(const WeightClass &weight, const string &op_name) const;
332
333   bool Write(const string &fname) const final { return impl_->Write(fname); }
334
335   bool Write(std::ostream &ostr, const string &fname) const final {
336     return impl_->Write(ostr, fname);
337   }
338
339   ~FstClass() override {}
340
341   // These methods are required by IO registration.
342
343   template <class Arc>
344   static FstClassImplBase *Convert(const FstClass &other) {
345     FSTERROR() << "Doesn't make sense to convert any class to type FstClass";
346     return nullptr;
347   }
348
349   template <class Arc>
350   static FstClassImplBase *Create() {
351     FSTERROR() << "Doesn't make sense to create an FstClass with a "
352                << "particular arc type";
353     return nullptr;
354   }
355
356   template <class Arc>
357   const Fst<Arc> *GetFst() const {
358     if (Arc::Type() != ArcType()) {
359       return nullptr;
360     } else {
361       FstClassImpl<Arc> *typed_impl =
362           static_cast<FstClassImpl<Arc> *>(impl_.get());
363       return typed_impl->GetImpl();
364     }
365   }
366
367   template <class Arc>
368   static FstClass *Read(std::istream &stream, const FstReadOptions &opts) {
369     if (!opts.header) {
370       LOG(ERROR) << "FstClass::Read: Options header not specified";
371       return nullptr;
372     }
373     const FstHeader &hdr = *opts.header;
374     if (hdr.Properties() & kMutable) {
375       return ReadTypedFst<MutableFstClass, MutableFst<Arc>>(stream, opts);
376     } else {
377       return ReadTypedFst<FstClass, Fst<Arc>>(stream, opts);
378     }
379   }
380
381  protected:
382   explicit FstClass(FstClassImplBase *impl) : impl_(impl) {}
383
384   const FstClassImplBase *GetImpl() const { return impl_.get(); }
385
386   FstClassImplBase *GetImpl() { return impl_.get(); }
387
388   // Generic template method for reading an arc-templated FST of type
389   // UnderlyingT, and returning it wrapped as FstClassT, with appropriat
390   // error checking. Called from arc-templated Read() static methods.
391   template <class FstClassT, class UnderlyingT>
392   static FstClassT *ReadTypedFst(std::istream &stream,
393                                  const FstReadOptions &opts) {
394     std::unique_ptr<UnderlyingT> u(UnderlyingT::Read(stream, opts));
395     return u ? new FstClassT(*u) : nullptr;
396   }
397
398  private:
399   std::unique_ptr<FstClassImplBase> impl_;
400 };
401
402 // Specific types of FstClass with special properties
403
404 class MutableFstClass : public FstClass {
405  public:
406   bool AddArc(int64 s, const ArcClass &ac) {
407     if (!WeightTypesMatch(ac.weight, "AddArc")) return false;
408     return GetImpl()->AddArc(s, ac);
409   }
410
411   int64 AddState() { return GetImpl()->AddState(); }
412
413   bool DeleteArcs(int64 s, size_t n) { return GetImpl()->DeleteArcs(s, n); }
414
415   bool DeleteArcs(int64 s) { return GetImpl()->DeleteArcs(s); }
416
417   bool DeleteStates(const std::vector<int64> &dstates) {
418     return GetImpl()->DeleteStates(dstates);
419   }
420
421   void DeleteStates() { GetImpl()->DeleteStates(); }
422
423   SymbolTable *MutableInputSymbols() {
424     return GetImpl()->MutableInputSymbols();
425   }
426
427   SymbolTable *MutableOutputSymbols() {
428     return GetImpl()->MutableOutputSymbols();
429   }
430
431   int64 NumStates() const { return GetImpl()->NumStates(); }
432
433   bool ReserveArcs(int64 s, size_t n) { return GetImpl()->ReserveArcs(s, n); }
434
435   void ReserveStates(int64 s) { GetImpl()->ReserveStates(s); }
436
437   static MutableFstClass *Read(const string &fname, bool convert = false);
438
439   void SetInputSymbols(SymbolTable *isyms) {
440     GetImpl()->SetInputSymbols(isyms);
441   }
442
443   bool SetFinal(int64 s, const WeightClass &weight) {
444     if (!WeightTypesMatch(weight, "SetFinal")) return false;
445     return GetImpl()->SetFinal(s, weight);
446   }
447
448   void SetOutputSymbols(SymbolTable *osyms) {
449     GetImpl()->SetOutputSymbols(osyms);
450   }
451
452   void SetProperties(uint64 props, uint64 mask) {
453     GetImpl()->SetProperties(props, mask);
454   }
455
456   bool SetStart(int64 s) { return GetImpl()->SetStart(s); }
457
458   template <class Arc>
459   explicit MutableFstClass(const MutableFst<Arc> &fst) : FstClass(fst) {}
460
461   // These methods are required by IO registration.
462
463   template <class Arc>
464   static FstClassImplBase *Convert(const FstClass &other) {
465     FSTERROR() << "Doesn't make sense to convert any class to type "
466                << "MutableFstClass";
467     return nullptr;
468   }
469
470   template <class Arc>
471   static FstClassImplBase *Create() {
472     FSTERROR() << "Doesn't make sense to create a MutableFstClass with a "
473                << "particular arc type";
474     return nullptr;
475   }
476
477   template <class Arc>
478   MutableFst<Arc> *GetMutableFst() {
479     Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>());
480     MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst);
481     return mfst;
482   }
483
484   template <class Arc>
485   static MutableFstClass *Read(std::istream &stream,
486                                const FstReadOptions &opts) {
487     std::unique_ptr<MutableFst<Arc>> mfst(MutableFst<Arc>::Read(stream, opts));
488     return mfst ? new MutableFstClass(*mfst) : nullptr;
489   }
490
491  protected:
492   explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) {}
493 };
494
495 class VectorFstClass : public MutableFstClass {
496  public:
497   explicit VectorFstClass(FstClassImplBase *impl) : MutableFstClass(impl) {}
498
499   explicit VectorFstClass(const FstClass &other);
500
501   explicit VectorFstClass(const string &arc_type);
502
503   static VectorFstClass *Read(const string &fname);
504
505   template <class Arc>
506   static VectorFstClass *Read(std::istream &stream,
507                               const FstReadOptions &opts) {
508     std::unique_ptr<VectorFst<Arc>> mfst(VectorFst<Arc>::Read(stream, opts));
509     return mfst ? new VectorFstClass(*mfst) : nullptr;
510   }
511
512   template <class Arc>
513   explicit VectorFstClass(const VectorFst<Arc> &fst) : MutableFstClass(fst) {}
514
515   template <class Arc>
516   static FstClassImplBase *Convert(const FstClass &other) {
517     return new FstClassImpl<Arc>(new VectorFst<Arc>(*other.GetFst<Arc>()),
518                                  true);
519   }
520
521   template <class Arc>
522   static FstClassImplBase *Create() {
523     return new FstClassImpl<Arc>(new VectorFst<Arc>(), true);
524   }
525 };
526
527 }  // namespace script
528 }  // namespace fst
529
530 #endif  // FST_SCRIPT_FST_CLASS_H_