1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 #ifndef FST_SCRIPT_FST_CLASS_H_
5 #define FST_SCRIPT_FST_CLASS_H_
10 #include <type_traits>
12 #include <fst/expanded-fst.h>
14 #include <fst/mutable-fst.h>
15 #include <fst/vector-fst.h>
16 #include <fst/script/arc-class.h>
17 #include <fst/script/weight-class.h>
19 // Classes to support "boxing" all existing types of FST arcs in a single
20 // FstClass which hides the arc types. This allows clients to load
21 // and work with FSTs without knowing the arc type. These classes are only
22 // recommended for use in high-level scripting applications. Most users should
23 // use the lower-level templated versions corresponding to these classes.
// NOTE(review): this file appears to be a numbered listing with lines dropped
// (the embedded numbers skip values, e.g. 32 -> 36 and 50 -> 53), so class
// heads, access specifiers, template headers and closing braces of several
// definitions are missing here; confirm against the upstream OpenFst header.
28 // Abstract base class defining the set of functionalities implemented in all
29 // impls and passed through by all bases. Below FstClassBase the class
30 // hierarchy bifurcates; FstClassImplBase serves as the base class for all
31 // implementations (of which FstClassImpl is currently the only one) and
32 // FstClass serves as the base class for all interfaces.
// Read-only accessors. State IDs are passed as int64 regardless of the
// wrapped arc's StateId type. In FstClassImpl, an invalid state ID makes
// Final() return WeightClass::NoWeight and the Num*() counters return
// size_t max.
36 virtual const string &ArcType() const = 0;
37 virtual WeightClass Final(int64) const = 0;
38 virtual const string &FstType() const = 0;
39 virtual const SymbolTable *InputSymbols() const = 0;
40 virtual size_t NumArcs(int64) const = 0;
41 virtual size_t NumInputEpsilons(int64) const = 0;
42 virtual size_t NumOutputEpsilons(int64) const = 0;
43 virtual const SymbolTable *OutputSymbols() const = 0;
// Arguments are (mask, test); FstClassImpl forwards them to Fst::Properties.
44 virtual uint64 Properties(uint64, bool) const = 0;
45 virtual int64 Start() const = 0;
46 virtual const string &WeightType() const = 0;
// Returns whether the argument names an existing state of the wrapped FST.
47 virtual bool ValidStateId(int64) const = 0;
// Serializes to a file name, or to a stream (the string is used as the
// source name in the write options by FstClassImpl).
48 virtual bool Write(const string &) const = 0;
49 virtual bool Write(std::ostream &, const string &) const = 0;
// Virtual so implementations are destroyed correctly through base pointers
// (FstClass holds its implementation in a std::unique_ptr<FstClassImplBase>).
50 virtual ~FstClassBase() {}
53 // Adds all the MutableFst methods.
// These mutators are only meaningful when the wrapped FST really is a
// MutableFst: FstClassImpl implements them with an unchecked
// static_cast<MutableFst<Arc> *>. In FstClassImpl, the bool-returning
// methods that take a state ID return false when that ID is invalid.
54 class FstClassImplBase : public FstClassBase {
56 virtual bool AddArc(int64, const ArcClass &) = 0;
57 virtual int64 AddState() = 0;
// Returns a newly allocated deep copy; the caller takes ownership (FstClass
// stores the result in a std::unique_ptr).
58 virtual FstClassImplBase *Copy() = 0;
59 virtual bool DeleteArcs(int64, size_t) = 0;
60 virtual bool DeleteArcs(int64) = 0;
61 virtual bool DeleteStates(const std::vector<int64> &) = 0;
62 virtual void DeleteStates() = 0;
63 virtual SymbolTable *MutableInputSymbols() = 0;
64 virtual SymbolTable *MutableOutputSymbols() = 0;
// Declared here rather than in FstClassBase; FstClassImpl implements it by
// casting the wrapped FST to MutableFst.
65 virtual int64 NumStates() const = 0;
66 virtual bool ReserveArcs(int64, size_t) = 0;
67 virtual void ReserveStates(int64) = 0;
68 virtual void SetInputSymbols(SymbolTable *) = 0;
69 virtual bool SetFinal(int64, const WeightClass &) = 0;
70 virtual void SetOutputSymbols(SymbolTable *) = 0;
// Arguments are (props, mask).
71 virtual void SetProperties(uint64, uint64) = 0;
72 virtual bool SetStart(int64) = 0;
73 ~FstClassImplBase() override {}
76 // Container class wrapping an Fst<Arc>, hiding its arc type. Whether this
77 // Fst<Arc> pointer refers to a special kind of FST (e.g. a MutableFst) is
78 // known by the type of interface class that owns the pointer to this
// object.
//
// NOTE(review): every mutating method below uses an unchecked
// static_cast<MutableFst<Arc> *>; calling one when the wrapped FST is not
// actually a MutableFst<Arc> is undefined behavior.
82 class FstClassImpl : public FstClassImplBase {
// If `should_own` is true, takes ownership of `impl`; otherwise stores
// impl->Copy() and the caller keeps ownership of `impl`.
84 explicit FstClassImpl(Fst<Arc> *impl, bool should_own = false)
85 : impl_(should_own ? impl : impl->Copy()) {}
// Stores a copy of `impl`.
87 explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) {}
89 // Warning: calling this method casts the FST to a mutable FST.
// Returns false if `s` is invalid (ValidStateId logs the error). The weight
// pointer from ac.weight.GetWeight<...>() is dereferenced without a check;
// MutableFstClass::AddArc verifies the weight type before calling this.
90 bool AddArc(int64 s, const ArcClass &ac) final {
91 if (!ValidStateId(s)) return false;
92 // Note that we do not check that the destination state is valid, so users
93 // can add arcs before they add the corresponding states. Verify can be
94 // used to determine whether any arc has a nonexisting destination.
95 Arc arc(ac.ilabel, ac.olabel, *ac.weight.GetWeight<typename Arc::Weight>(),
97 static_cast<MutableFst<Arc> *>(impl_.get())->AddArc(s, arc);
101 // Warning: calling this method casts the FST to a mutable FST.
102 int64 AddState() final {
103 return static_cast<MutableFst<Arc> *>(impl_.get())->AddState();
106 const string &ArcType() const final { return Arc::Type(); }
// Deep copy: the pointer constructor with should_own == false copies impl_.
108 FstClassImpl *Copy() final { return new FstClassImpl<Arc>(impl_.get()); }
110 // Warning: calling this method casts the FST to a mutable FST.
111 bool DeleteArcs(int64 s, size_t n) final {
112 if (!ValidStateId(s)) return false;
113 static_cast<MutableFst<Arc> *>(impl_.get())->DeleteArcs(s, n);
117 // Warning: calling this method casts the FST to a mutable FST.
118 bool DeleteArcs(int64 s) final {
119 if (!ValidStateId(s)) return false;
120 static_cast<MutableFst<Arc> *>(impl_.get())->DeleteArcs(s);
124 // Warning: calling this method casts the FST to a mutable FST.
// Validates every ID before deleting anything, so if any ID is invalid it
// returns false and leaves the FST unchanged.
125 bool DeleteStates(const std::vector<int64> &dstates) final {
126 for (const auto &state : dstates)
127 if (!ValidStateId(state)) return false;
128 // Warning: calling this method with any integers beyond the precision of
129 // the underlying FST will result in truncation.
130 std::vector<typename Arc::StateId> typed_dstates(dstates.size());
131 std::copy(dstates.begin(), dstates.end(), typed_dstates.begin());
132 static_cast<MutableFst<Arc> *>(impl_.get())->DeleteStates(typed_dstates);
136 // Warning: calling this method casts the FST to a mutable FST.
137 void DeleteStates() final {
138 static_cast<MutableFst<Arc> *>(impl_.get())->DeleteStates();
// Returns WeightClass::NoWeight(WeightType()) for an invalid state ID.
141 WeightClass Final(int64 s) const final {
142 if (!ValidStateId(s)) return WeightClass::NoWeight(WeightType());
143 WeightClass w(impl_->Final(s));
147 const string &FstType() const final { return impl_->Type(); }
149 const SymbolTable *InputSymbols() const final {
150 return impl_->InputSymbols();
153 // Warning: calling this method casts the FST to a mutable FST.
154 SymbolTable *MutableInputSymbols() final {
155 return static_cast<MutableFst<Arc> *>(impl_.get())->MutableInputSymbols();
158 // Warning: calling this method casts the FST to a mutable FST.
159 SymbolTable *MutableOutputSymbols() final {
160 return static_cast<MutableFst<Arc> *>(impl_.get())->MutableOutputSymbols();
163 // Signals failure by returning size_t max.
164 size_t NumArcs(int64 s) const final {
165 return ValidStateId(s) ? impl_->NumArcs(s)
166 : std::numeric_limits<size_t>::max();
169 // Signals failure by returning size_t max.
170 size_t NumInputEpsilons(int64 s) const final {
171 return ValidStateId(s) ? impl_->NumInputEpsilons(s)
172 : std::numeric_limits<size_t>::max();
175 // Signals failure by returning size_t max.
176 size_t NumOutputEpsilons(int64 s) const final {
177 return ValidStateId(s) ? impl_->NumOutputEpsilons(s)
178 : std::numeric_limits<size_t>::max();
181 // Warning: calling this method casts the FST to a mutable FST.
// NOTE(review): NumStates is also available on ExpandedFst (whose header is
// included), which would be a weaker cast; confirm whether MutableFst is
// really required here.
182 int64 NumStates() const final {
183 return static_cast<MutableFst<Arc> *>(impl_.get())->NumStates();
186 uint64 Properties(uint64 mask, bool test) const final {
187 return impl_->Properties(mask, test);
190 // Warning: calling this method casts the FST to a mutable FST.
191 bool ReserveArcs(int64 s, size_t n) final {
192 if (!ValidStateId(s)) return false;
193 static_cast<MutableFst<Arc> *>(impl_.get())->ReserveArcs(s, n);
197 // Warning: calling this method casts the FST to a mutable FST.
198 void ReserveStates(int64 s) final {
199 static_cast<MutableFst<Arc> *>(impl_.get())->ReserveStates(s);
202 const SymbolTable *OutputSymbols() const final {
203 return impl_->OutputSymbols();
206 // Warning: calling this method casts the FST to a mutable FST.
207 void SetInputSymbols(SymbolTable *isyms) final {
208 static_cast<MutableFst<Arc> *>(impl_.get())->SetInputSymbols(isyms);
211 // Warning: calling this method casts the FST to a mutable FST.
// As in AddArc, the GetWeight<...>() result is dereferenced without a check;
// MutableFstClass::SetFinal verifies the weight type before calling this.
212 bool SetFinal(int64 s, const WeightClass &weight) final {
213 if (!ValidStateId(s)) return false;
214 static_cast<MutableFst<Arc> *>(impl_.get())
215 ->SetFinal(s, *weight.GetWeight<typename Arc::Weight>());
219 // Warning: calling this method casts the FST to a mutable FST.
220 void SetOutputSymbols(SymbolTable *osyms) final {
221 static_cast<MutableFst<Arc> *>(impl_.get())->SetOutputSymbols(osyms);
224 // Warning: calling this method casts the FST to a mutable FST.
225 void SetProperties(uint64 props, uint64 mask) final {
226 static_cast<MutableFst<Arc> *>(impl_.get())->SetProperties(props, mask);
229 // Warning: calling this method casts the FST to a mutable FST.
230 bool SetStart(int64 s) final {
231 if (!ValidStateId(s)) return false;
232 static_cast<MutableFst<Arc> *>(impl_.get())->SetStart(s);
236 int64 Start() const final { return impl_->Start(); }
// Rejects `s`, logging an FSTERROR, when the FST lacks the kExpanded
// property or when `s` lies outside [0, CountStates(*impl_)).
238 bool ValidStateId(int64 s) const final {
239 // This cowardly refuses to count states if the FST is not yet expanded.
240 if (!Properties(kExpanded, true)) {
241 FSTERROR() << "Cannot get number of states for unexpanded FST";
244 // If the FST is already expanded, CountStates calls NumStates.
245 if (s < 0 || s >= CountStates(*impl_)) {
246 FSTERROR() << "State ID " << s << " not valid";
252 const string &WeightType() const final { return Arc::Weight::Type(); }
254 bool Write(const string &fname) const final { return impl_->Write(fname); }
// Writes to `ostr`, passing `fname` as the FstWriteOptions source name.
256 bool Write(std::ostream &ostr, const string &fname) const final {
257 const FstWriteOptions opts(fname);
258 return impl_->Write(ostr, opts);
261 ~FstClassImpl() override {}
// Non-owning access to the wrapped FST; ownership stays with impl_.
263 Fst<Arc> *GetImpl() const { return impl_.get(); }
// The wrapped FST, owned by this object.
266 std::unique_ptr<Fst<Arc>> impl_;
269 // BASE CLASS DEFINITIONS
271 class MutableFstClass;
// Type-erased interface to an FST of any arc type. Owns its implementation
// through impl_ and deep-copies it on copy construction and assignment.
273 class FstClass : public FstClassBase {
// Creates an FstClass with a null impl_. NOTE(review): only Properties()
// checks for a null impl_; the other accessors dereference it
// unconditionally, so they must not be called on such an object.
275 FstClass() : impl_(nullptr) {}
// Wraps a copy of `fst`.
278 explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) {}
// Copy construction and assignment deep-copy the implementation via Copy();
// a null impl_ stays null. In operator=, Copy() is evaluated before reset()
// releases the old implementation, so self-assignment is safe.
280 FstClass(const FstClass &other)
281 : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {}
283 FstClass &operator=(const FstClass &other) {
284 impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy());
// The accessors below forward directly to impl_.
288 WeightClass Final(int64 s) const final { return impl_->Final(s); }
290 const string &ArcType() const final { return impl_->ArcType(); }
292 const string &FstType() const final { return impl_->FstType(); }
294 const SymbolTable *InputSymbols() const final {
295 return impl_->InputSymbols();
298 size_t NumArcs(int64 s) const final { return impl_->NumArcs(s); }
300 size_t NumInputEpsilons(int64 s) const final {
301 return impl_->NumInputEpsilons(s);
304 size_t NumOutputEpsilons(int64 s) const final {
305 return impl_->NumOutputEpsilons(s);
308 const SymbolTable *OutputSymbols() const final {
309 return impl_->OutputSymbols();
312 uint64 Properties(uint64 mask, bool test) const final {
313 // Special handling for FSTs with a null impl.
314 if (!impl_) return kError & mask;
315 return impl_->Properties(mask, test);
// Non-templated readers; only declared here, their definitions are not part
// of this listing.
318 static FstClass *Read(const string &fname);
320 static FstClass *Read(std::istream &istrm, const string &source);
322 int64 Start() const final { return impl_->Start(); }
324 bool ValidStateId(int64 s) const final { return impl_->ValidStateId(s); }
326 const string &WeightType() const final { return impl_->WeightType(); }
328 // Helper that logs an ERROR if the weight type of an FST and a WeightClass
// do not match. Returns true when they match: callers such as
// MutableFstClass::AddArc return false as soon as this returns false.
// `op_name` presumably names the calling operation in the log message.
331 bool WeightTypesMatch(const WeightClass &weight, const string &op_name) const;
333 bool Write(const string &fname) const final { return impl_->Write(fname); }
335 bool Write(std::ostream &ostr, const string &fname) const final {
336 return impl_->Write(ostr, fname);
339 ~FstClass() override {}
341 // These methods are required by IO registration.
// Both log an FSTERROR: the generic FstClass cannot be converted to, or
// created for, a specific arc type.
344 static FstClassImplBase *Convert(const FstClass &other) {
345 FSTERROR() << "Doesn't make sense to convert any class to type FstClass";
350 static FstClassImplBase *Create() {
351 FSTERROR() << "Doesn't make sense to create an FstClass with a "
352 << "particular arc type";
// Returns the typed FST when Arc matches ArcType(). The body of the mismatch
// branch is not part of this listing (presumably it returns nullptr; confirm).
// The returned pointer is non-owning.
357 const Fst<Arc> *GetFst() const {
358 if (Arc::Type() != ArcType()) {
361 FstClassImpl<Arc> *typed_impl =
362 static_cast<FstClassImpl<Arc> *>(impl_.get());
363 return typed_impl->GetImpl();
// Arc-templated reader. Logs an error when no options header is supplied
// (the check itself is not part of this listing). FSTs whose header has
// the kMutable property are returned as a MutableFstClass, all others as a
// plain FstClass.
368 static FstClass *Read(std::istream &stream, const FstReadOptions &opts) {
370 LOG(ERROR) << "FstClass::Read: Options header not specified";
373 const FstHeader &hdr = *opts.header;
374 if (hdr.Properties() & kMutable) {
375 return ReadTypedFst<MutableFstClass, MutableFst<Arc>>(stream, opts);
377 return ReadTypedFst<FstClass, Fst<Arc>>(stream, opts);
// Takes ownership of `impl`; used by subclass constructors.
382 explicit FstClass(FstClassImplBase *impl) : impl_(impl) {}
384 const FstClassImplBase *GetImpl() const { return impl_.get(); }
386 FstClassImplBase *GetImpl() { return impl_.get(); }
388 // Generic template method for reading an arc-templated FST of type
389 // UnderlyingT, and returning it wrapped as FstClassT, with appropriate
390 // error checking. Called from arc-templated Read() static methods.
// Returns nullptr if UnderlyingT::Read fails. Otherwise the wrapper holds a
// copy of the FST that was read, and the temporary is freed by the
// unique_ptr.
391 template <class FstClassT, class UnderlyingT>
392 static FstClassT *ReadTypedFst(std::istream &stream,
393 const FstReadOptions &opts) {
394 std::unique_ptr<UnderlyingT> u(UnderlyingT::Read(stream, opts));
395 return u ? new FstClassT(*u) : nullptr;
// Owned implementation; may be null (e.g. for a default-constructed object).
399 std::unique_ptr<FstClassImplBase> impl_;
402 // Specific types of FstClass with special properties
// Type-erased mutable FST. Mutators forward to the implementation, which
// static_casts the wrapped FST to MutableFst<Arc>.
404 class MutableFstClass : public FstClass {
// AddArc and SetFinal first check that the weight's type matches this FST's
// weight type. On a mismatch they return false without touching the FST.
406 bool AddArc(int64 s, const ArcClass &ac) {
407 if (!WeightTypesMatch(ac.weight, "AddArc")) return false;
408 return GetImpl()->AddArc(s, ac);
411 int64 AddState() { return GetImpl()->AddState(); }
413 bool DeleteArcs(int64 s, size_t n) { return GetImpl()->DeleteArcs(s, n); }
415 bool DeleteArcs(int64 s) { return GetImpl()->DeleteArcs(s); }
417 bool DeleteStates(const std::vector<int64> &dstates) {
418 return GetImpl()->DeleteStates(dstates);
421 void DeleteStates() { GetImpl()->DeleteStates(); }
423 SymbolTable *MutableInputSymbols() {
424 return GetImpl()->MutableInputSymbols();
427 SymbolTable *MutableOutputSymbols() {
428 return GetImpl()->MutableOutputSymbols();
431 int64 NumStates() const { return GetImpl()->NumStates(); }
433 bool ReserveArcs(int64 s, size_t n) { return GetImpl()->ReserveArcs(s, n); }
435 void ReserveStates(int64 s) { GetImpl()->ReserveStates(s); }
// Declared here only; the meaning of `convert` is defined where this is
// implemented, which is not part of this listing.
437 static MutableFstClass *Read(const string &fname, bool convert = false);
439 void SetInputSymbols(SymbolTable *isyms) {
440 GetImpl()->SetInputSymbols(isyms);
443 bool SetFinal(int64 s, const WeightClass &weight) {
444 if (!WeightTypesMatch(weight, "SetFinal")) return false;
445 return GetImpl()->SetFinal(s, weight);
448 void SetOutputSymbols(SymbolTable *osyms) {
449 GetImpl()->SetOutputSymbols(osyms);
452 void SetProperties(uint64 props, uint64 mask) {
453 GetImpl()->SetProperties(props, mask);
456 bool SetStart(int64 s) { return GetImpl()->SetStart(s); }
// Wraps a copy of `fst`.
459 explicit MutableFstClass(const MutableFst<Arc> &fst) : FstClass(fst) {}
461 // These methods are required by IO registration.
// Both log an FSTERROR: the generic MutableFstClass cannot be converted to,
// or created for, a specific arc type.
464 static FstClassImplBase *Convert(const FstClass &other) {
465 FSTERROR() << "Doesn't make sense to convert any class to type "
466 << "MutableFstClass";
471 static FstClassImplBase *Create() {
472 FSTERROR() << "Doesn't make sense to create a MutableFstClass with a "
473 << "particular arc type";
// Returns the wrapped FST as a MutableFst<Arc> by const_cast-ing the result
// of GetFst<Arc>() and applying an unchecked static_cast. The rest of the
// body is not part of this listing.
478 MutableFst<Arc> *GetMutableFst() {
479 Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>());
480 MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst);
// Arc-templated reader: returns nullptr if MutableFst<Arc>::Read fails,
// otherwise a MutableFstClass wrapping a copy of the FST that was read.
485 static MutableFstClass *Read(std::istream &stream,
486 const FstReadOptions &opts) {
487 std::unique_ptr<MutableFst<Arc>> mfst(MutableFst<Arc>::Read(stream, opts));
488 return mfst ? new MutableFstClass(*mfst) : nullptr;
// Takes ownership of `impl`.
492 explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) {}
// Type-erased mutable FST intended to wrap a VectorFst<Arc>.
495 class VectorFstClass : public MutableFstClass {
// Takes ownership of `impl`.
497 explicit VectorFstClass(FstClassImplBase *impl) : MutableFstClass(impl) {}
// Declared here only; their definitions are not part of this listing.
// Presumably they convert `other` to a VectorFst, and create an empty FST
// of the named arc type, respectively (confirm against the implementation).
499 explicit VectorFstClass(const FstClass &other);
501 explicit VectorFstClass(const string &arc_type);
503 static VectorFstClass *Read(const string &fname);
// Arc-templated reader: returns nullptr if VectorFst<Arc>::Read fails,
// otherwise a VectorFstClass wrapping a copy of the FST that was read.
506 static VectorFstClass *Read(std::istream &stream,
507 const FstReadOptions &opts) {
508 std::unique_ptr<VectorFst<Arc>> mfst(VectorFst<Arc>::Read(stream, opts));
509 return mfst ? new VectorFstClass(*mfst) : nullptr;
// Wraps a copy of `fst`.
513 explicit VectorFstClass(const VectorFst<Arc> &fst) : MutableFstClass(fst) {}
// IO-registration hooks.
// Convert builds a new VectorFst<Arc> from the FST returned by
// other.GetFst<Arc>(). NOTE(review): that pointer is dereferenced without a
// null check; confirm callers only pass an FstClass whose ArcType() matches
// Arc.
// Create returns an implementation that owns (should_own == true) a new,
// default-constructed VectorFst<Arc>.
516 static FstClassImplBase *Convert(const FstClass &other) {
517 return new FstClassImpl<Arc>(new VectorFst<Arc>(*other.GetFst<Arc>()),
522 static FstClassImplBase *Create() {
523 return new FstClassImpl<Arc>(new VectorFst<Arc>(), true);
527 } // namespace script
530 #endif // FST_SCRIPT_FST_CLASS_H_