9f0cc3ca0ebc8cc739b96d09ec5d792ba5f45cac
[platform/upstream/openfst.git] / src / include / fst / fst.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // FST abstract base class definition, state and arc iterator interface, and
5 // suggested base implementation.
6
7 #ifndef FST_LIB_FST_H_
8 #define FST_LIB_FST_H_
9
10 #include <sys/types.h>
11
12 #include <cmath>
13 #include <cstddef>
14
15 #include <iostream>
16 #include <memory>
17 #include <sstream>
18 #include <string>
19 #include <utility>
20
21 #include <fst/compat.h>
22 #include <fst/types.h>
23 #include <fst/log.h>
24 #include <fstream>
25
26 #include <fst/arc.h>
27 #include <fst/memory.h>
28 #include <fst/properties.h>
29 #include <fst/register.h>
30 #include <fst/symbol-table.h>
31 #include <fst/util.h>
32
33
34 DECLARE_bool(fst_align);
35
36 namespace fst {
37
38 bool IsFstHeader(std::istream &, const string &);
39
40 class FstHeader;
41 template <class Arc>
42
43 struct StateIteratorData;
44 template <class Arc>
45
46 struct ArcIteratorData;
47
48 template <class Arc>
49 class MatcherBase;
50
51 struct FstReadOptions {
52   // FileReadMode(s) are advisory, there are many conditions than prevent a
53   // file from being mapped, READ mode will be selected in these cases with
54   // a warning indicating why it was chosen.
55   enum FileReadMode { READ, MAP };
56
57   string source;                // Where you're reading from.
58   const FstHeader *header;      // Pointer to FST header; if non-zero, use
59                                 // this info (don't read a stream header).
60   const SymbolTable *isymbols;  // Pointer to input symbols; if non-zero, use
61                                 // this info (read and skip stream isymbols)
62   const SymbolTable *osymbols;  // Pointer to output symbols; if non-zero, use
63                                 // this info (read and skip stream osymbols)
64   FileReadMode mode;            // Read or map files (advisory, if possible)
65   bool read_isymbols;           // Read isymbols, if any (default: true).
66   bool read_osymbols;           // Read osymbols, if any (default: true).
67
68   explicit FstReadOptions(const string &source = "<unspecified>",
69                           const FstHeader *header = nullptr,
70                           const SymbolTable *isymbols = nullptr,
71                           const SymbolTable *osymbols = nullptr);
72
73   explicit FstReadOptions(const string &source, const SymbolTable *isymbols,
74                           const SymbolTable *osymbols = nullptr);
75
76   // Helper function to convert strings FileReadModes into their enum value.
77   static FileReadMode ReadMode(const string &mode);
78
79   // Outputs a debug string for the FstReadOptions object.
80   string DebugString() const;
81 };
82
83 struct FstWriteOptions {
84   string source;        // Where you're writing to.
85   bool write_header;    // Write the header?
86   bool write_isymbols;  // Write input symbols?
87   bool write_osymbols;  // Write output symbols?
88   bool align;           // Write data aligned (may fail on pipes)?
89   bool stream_write;    // Avoid seek operations in writing.
90
91   explicit FstWriteOptions(const string &source = "<unspecifed>",
92                            bool write_header = true, bool write_isymbols = true,
93                            bool write_osymbols = true,
94                            bool align = FLAGS_fst_align,
95                            bool stream_write = false)
96       : source(source),
97         write_header(write_header),
98         write_isymbols(write_isymbols),
99         write_osymbols(write_osymbols),
100         align(align),
101         stream_write(stream_write) {}
102 };
103
104 // Header class.
105 //
106 // This is the recommended file header representation.
107
108 class FstHeader {
109  public:
110   enum {
111     HAS_ISYMBOLS = 0x1,  // Has input symbol table.
112     HAS_OSYMBOLS = 0x2,  // Has output symbol table.
113     IS_ALIGNED = 0x4,    // Memory-aligned (where appropriate).
114   } Flags;
115
116   FstHeader() : version_(0), flags_(0), properties_(0), start_(-1),
117       numstates_(0), numarcs_(0) {}
118
119   const string &FstType() const { return fsttype_; }
120
121   const string &ArcType() const { return arctype_; }
122
123   int32 Version() const { return version_; }
124
125   int32 GetFlags() const { return flags_; }
126
127   uint64 Properties() const { return properties_; }
128
129   int64 Start() const { return start_; }
130
131   int64 NumStates() const { return numstates_; }
132
133   int64 NumArcs() const { return numarcs_; }
134
135   void SetFstType(const string &type) { fsttype_ = type; }
136
137   void SetArcType(const string &type) { arctype_ = type; }
138
139   void SetVersion(int32 version) { version_ = version; }
140
141   void SetFlags(int32 flags) { flags_ = flags; }
142
143   void SetProperties(uint64 properties) { properties_ = properties; }
144
145   void SetStart(int64 start) { start_ = start; }
146
147   void SetNumStates(int64 numstates) { numstates_ = numstates; }
148
149   void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; }
150
151   bool Read(std::istream &strm, const string &source,
152             bool rewind = false);
153
154   bool Write(std::ostream &strm, const string &source) const;
155
156   // Outputs a debug string for the FstHeader object.
157   string DebugString() const;
158
159  private:
160   string fsttype_;     // E.g. "vector".
161   string arctype_;     // E.g. "standard".
162   int32 version_;      // Type version number.
163   int32 flags_;        // File format bits.
164   uint64 properties_;  // FST property bits.
165   int64 start_;        // Start state.
166   int64 numstates_;    // # of states.
167   int64 numarcs_;      // # of arcs.
168 };
169
170 // Specifies matcher action.
171 enum MatchType {
172   MATCH_INPUT = 1,   // Match input label.
173   MATCH_OUTPUT = 2,  // Match output label.
174   MATCH_BOTH = 3,    // Match input or output label.
175   MATCH_NONE = 4,    // Match nothing.
176   MATCH_UNKNOWN = 5
177 };  // Otherwise, match type unknown.
178
179 constexpr int kNoStateId = -1;  // Not a valid state ID.
180 constexpr int kNoLabel = -1;    // Not a valid label.
181
182 // A generic FST, templated on the arc definition, with common-demoninator
183 // methods (use StateIterator and ArcIterator to iterate over its states and
184 // arcs).
185 template <class A>
186 class Fst {
187  public:
188   using Arc = A;
189   using StateId = typename Arc::StateId;
190   using Weight = typename Arc::Weight;
191
192   virtual ~Fst() {}
193
194   // Initial state.
195   virtual StateId Start() const = 0;
196
197   // State's final weight.
198   virtual Weight Final(StateId) const = 0;
199
200   // State's arc count.
201   virtual size_t NumArcs(StateId) const = 0;
202
203   // State's input epsilon count.
204   virtual size_t NumInputEpsilons(StateId) const = 0;
205
206   // State's output epsilon count.
207   virtual size_t NumOutputEpsilons(StateId) const = 0;
208
209   // Property bits. If test = false, return stored properties bits for mask
210   // (some possibly unknown); if test = true, return property bits for mask
211   // (computing o.w. unknown).
212   virtual uint64 Properties(uint64 mask, bool test) const = 0;
213
214   // FST type name.
215   virtual const string &Type() const = 0;
216
217   // Gets a copy of this Fst. The copying behaves as follows:
218   //
219   // (1) The copying is constant time if safe = false or if safe = true
220   // and is on an otherwise unaccessed FST.
221   //
222   // (2) If safe = true, the copy is thread-safe in that the original
223   // and copy can be safely accessed (but not necessarily mutated) by
224   // separate threads. For some FST types, 'Copy(true)' should only be
225   // called on an FST that has not otherwise been accessed. Behavior is
226   // otherwise undefined.
227   //
228   // (3) If a MutableFst is copied and then mutated, then the original is
229   // unmodified and vice versa (often by a copy-on-write on the initial
230   // mutation, which may not be constant time).
231   virtual Fst<Arc> *Copy(bool safe = false) const = 0;
232
233   // Reads an FST from an input stream; returns nullptr on error.
234   static Fst<Arc> *Read(std::istream &strm, const FstReadOptions &opts) {
235     FstReadOptions ropts(opts);
236     FstHeader hdr;
237     if (ropts.header) {
238       hdr = *opts.header;
239     } else {
240       if (!hdr.Read(strm, opts.source)) return nullptr;
241       ropts.header = &hdr;
242     }
243     const auto &fst_type = hdr.FstType();
244     const auto reader = FstRegister<Arc>::GetRegister()->GetReader(fst_type);
245     if (!reader) {
246       LOG(ERROR) << "Fst::Read: Unknown FST type " << fst_type
247                  << " (arc type = " << Arc::Type() << "): " << ropts.source;
248       return nullptr;
249     }
250     return reader(strm, ropts);
251   }
252
253   // Reads an FST from a file; returns nullptr on error. An empty filename
254   // results in reading from standard input.
255   static Fst<Arc> *Read(const string &filename) {
256     if (!filename.empty()) {
257       std::ifstream strm(filename,
258                               std::ios_base::in | std::ios_base::binary);
259       if (!strm) {
260         LOG(ERROR) << "Fst::Read: Can't open file: " << filename;
261         return nullptr;
262       }
263       return Read(strm, FstReadOptions(filename));
264     } else {
265       return Read(std::cin, FstReadOptions("standard input"));
266     }
267   }
268
269   // Writes an FST to an output stream; returns false on error.
270   virtual bool Write(std::ostream &strm, const FstWriteOptions &opts) const {
271     LOG(ERROR) << "Fst::Write: No write stream method for " << Type()
272                << " FST type";
273     return false;
274   }
275
276   // Writes an FST to a file; returns false on error; an empty filename
277   // results in writing to standard output.
278   virtual bool Write(const string &filename) const {
279     LOG(ERROR) << "Fst::Write: No write filename method for " << Type()
280                << " FST type";
281     return false;
282   }
283
284   // Returns input label symbol table; return nullptr if not specified.
285   virtual const SymbolTable *InputSymbols() const = 0;
286
287   // Return output label symbol table; return nullptr if not specified.
288   virtual const SymbolTable *OutputSymbols() const = 0;
289
290   // For generic state iterator construction (not normally called directly by
291   // users).
292   virtual void InitStateIterator(StateIteratorData<Arc> *data) const = 0;
293
294   // For generic arc iterator construction (not normally called directly by
295   // users).
296   virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const = 0;
297
298   // For generic matcher construction (not normally called directly by users).
299   virtual MatcherBase<Arc> *InitMatcher(MatchType match_type) const;
300
301  protected:
302   bool WriteFile(const string &filename) const {
303     if (!filename.empty()) {
304       std::ofstream strm(filename,
305                                std::ios_base::out | std::ios_base::binary);
306       if (!strm) {
307         LOG(ERROR) << "Fst::Write: Can't open file: " << filename;
308         return false;
309       }
310       bool val = Write(strm, FstWriteOptions(filename));
311       if (!val) LOG(ERROR) << "Fst::Write failed: " << filename;
312       return val;
313     } else {
314       return Write(std::cout, FstWriteOptions("standard output"));
315     }
316   }
317 };
318
319 // A useful alias when using StdArc.
320 using StdFst = Fst<StdArc>;
321
322 // State and arc iterator definitions.
323 //
324 // State iterator interface templated on the Arc definition; used for
325 // StateIterator specializations returned by the InitStateIterator FST method.
326 template <class Arc>
327 class StateIteratorBase {
328  public:
329   using StateId = typename Arc::StateId;
330
331   virtual ~StateIteratorBase() {}
332
333   // End of iterator?
334   virtual bool Done() const = 0;
335   // Returns current state (when !Done()).
336   virtual StateId Value() const = 0;
337   // Advances to next state (when !Done()).
338   virtual void Next() = 0;
339   // Resets to initial condition.
340   virtual void Reset() = 0;
341 };
342
343 // StateIterator initialization data.
344
345 template <class Arc>
346 struct StateIteratorData {
347   using StateId = typename Arc::StateId;
348
349   // Specialized iterator if non-zero.
350   StateIteratorBase<Arc> *base;
351   // Otherwise, the total number of states.
352   StateId nstates;
353
354   StateIteratorData() : base(nullptr), nstates(0) {}
355
356   StateIteratorData(const StateIteratorData &) = delete;
357   StateIteratorData &operator=(const StateIteratorData &) = delete;
358 };
359
360 // Generic state iterator, templated on the FST definition (a wrapper
361 // around a pointer to a specific one). Here is a typical use:
362 //
363 //   for (StateIterator<StdFst> siter(fst);
364 //        !siter.Done();
365 //        siter.Next()) {
366 //     StateId s = siter.Value();
367 //     ...
368 //   }
369 template <class FST>
370 class StateIterator {
371  public:
372   using Arc = typename FST::Arc;
373   using StateId = typename Arc::StateId;
374
375   explicit StateIterator(const FST &fst) : s_(0) {
376     fst.InitStateIterator(&data_);
377   }
378
379   ~StateIterator() { delete data_.base; }
380
381   bool Done() const {
382     return data_.base ? data_.base->Done() : s_ >= data_.nstates;
383   }
384
385   StateId Value() const { return data_.base ? data_.base->Value() : s_; }
386
387   void Next() {
388     if (data_.base) {
389       data_.base->Next();
390     } else {
391       ++s_;
392     }
393   }
394
395   void Reset() {
396     if (data_.base) {
397       data_.base->Reset();
398     } else {
399       s_ = 0;
400     }
401   }
402
403  private:
404   StateIteratorData<Arc> data_;
405   StateId s_;
406 };
407
408 // Flags to control the behavior on an arc iterator.
409 static constexpr uint32 kArcILabelValue =
410     0x0001;  // Value() gives valid ilabel.
411 static constexpr uint32 kArcOLabelValue = 0x0002;  //  "       "     " olabel.
412 static constexpr uint32 kArcWeightValue = 0x0004;  //  "       "     " weight.
413 static constexpr uint32 kArcNextStateValue =
414     0x0008;                                    //  "       "     " nextstate.
415 static constexpr uint32 kArcNoCache = 0x0010;  // No need to cache arcs.
416
417 static constexpr uint32 kArcValueFlags =
418     kArcILabelValue | kArcOLabelValue | kArcWeightValue | kArcNextStateValue;
419
420 static constexpr uint32 kArcFlags = kArcValueFlags | kArcNoCache;
421
422 // Arc iterator interface, templated on the arc definition; used for arc
423 // iterator specializations that are returned by the InitArcIterator FST method.
424 template <class Arc>
425 class ArcIteratorBase {
426  public:
427   using StateId = typename Arc::StateId;
428
429   virtual ~ArcIteratorBase() {}
430
431   // End of iterator?
432   virtual bool Done() const = 0;
433   // Returns current arc (when !Done()).
434   virtual const Arc &Value() const = 0;
435   // Advances to next arc (when !Done()).
436   virtual void Next() = 0;
437   // Returns current position.
438   virtual size_t Position() const = 0;
439   // Returns to initial condition.
440   virtual void Reset() = 0;
441   // Advances to arbitrary arc by position.
442   virtual void Seek(size_t) = 0;
443   // Returns current behavorial flags
444   virtual uint32 Flags() const = 0;
445   // Sets behavorial flags.
446   virtual void SetFlags(uint32, uint32) = 0;
447 };
448
449 // ArcIterator initialization data.
450 template <class Arc>
451 struct ArcIteratorData {
452   ArcIteratorData()
453       : base(nullptr), arcs(nullptr), narcs(0), ref_count(nullptr) {}
454
455   ArcIteratorData(const ArcIteratorData &) = delete;
456
457   ArcIteratorData &operator=(const ArcIteratorData &) = delete;
458
459   ArcIteratorBase<Arc> *base;  // Specialized iterator if non-zero.
460   const Arc *arcs;             // O.w. arcs pointer
461   size_t narcs;                // ... and arc count.
462   int *ref_count;              // ... and reference count if non-zero.
463 };
464
465 // Generic arc iterator, templated on the FST definition (a wrapper around a
466 // pointer to a specific one). Here is a typical use:
467 //
468 //   for (ArcIterator<StdFst> aiter(fst, s);
469 //        !aiter.Done();
470 //         aiter.Next()) {
471 //     StdArc &arc = aiter.Value();
472 //     ...
473 //   }
474 template <class FST>
475 class ArcIterator {
476  public:
477   using Arc = typename FST::Arc;
478   using StateId = typename Arc::StateId;
479
480   ArcIterator(const FST &fst, StateId s) : i_(0) {
481     fst.InitArcIterator(s, &data_);
482   }
483
484   explicit ArcIterator(const ArcIteratorData<Arc> &data) : data_(data), i_(0) {
485     if (data_.ref_count) ++(*data_.ref_count);
486   }
487
488   ~ArcIterator() {
489     if (data_.base) {
490       delete data_.base;
491     } else if (data_.ref_count) {
492       --(*data_.ref_count);
493     }
494   }
495
496   bool Done() const {
497     return data_.base ? data_.base->Done() : i_ >= data_.narcs;
498   }
499
500   const Arc &Value() const {
501     return data_.base ? data_.base->Value() : data_.arcs[i_];
502   }
503
504   void Next() {
505     if (data_.base) {
506       data_.base->Next();
507     } else {
508       ++i_;
509     }
510   }
511
512   void Reset() {
513     if (data_.base) {
514       data_.base->Reset();
515     } else {
516       i_ = 0;
517     }
518   }
519
520   void Seek(size_t a) {
521     if (data_.base) {
522       data_.base->Seek(a);
523     } else {
524       i_ = a;
525     }
526   }
527
528   size_t Position() const { return data_.base ? data_.base->Position() : i_; }
529
530   uint32 Flags() const {
531     if (data_.base) {
532       return data_.base->Flags();
533     } else {
534       return kArcValueFlags;
535     }
536   }
537
538   void SetFlags(uint32 flags, uint32 mask) {
539     if (data_.base) data_.base->SetFlags(flags, mask);
540   }
541
542  private:
543   ArcIteratorData<Arc> data_;
544   size_t i_;
545 };
546
547 }  // namespace fst
548
549 // ArcIterator placement operator new and destroy function; new needs to be in
550 // the global namespace.
551
552 template <class FST>
553 void *operator new(size_t size,
554                    fst::MemoryPool<fst::ArcIterator<FST>> *pool) {
555   return pool->Allocate();
556 }
557
558 namespace fst {
559
560 template <class FST>
561 void Destroy(ArcIterator<FST> *aiter, MemoryPool<ArcIterator<FST>> *pool) {
562   if (aiter) {
563     aiter->~ArcIterator<FST>();
564     pool->Free(aiter);
565   }
566 }
567
568 // Matcher definitions.
569
570 template <class Arc>
571 MatcherBase<Arc> *Fst<Arc>::InitMatcher(MatchType match_type) const {
572   return nullptr;  // One should just use the default matcher.
573 }
574
575 // FST accessors, useful in high-performance applications.
576
577 namespace internal {
578
579 // General case, requires non-abstract, 'final' methods. Use for inlining.
580
581 template <class F>
582 inline typename F::Arc::Weight Final(const F &fst, typename F::Arc::StateId s) {
583   return fst.F::Final(s);
584 }
585
586 template <class F>
587 inline ssize_t NumArcs(const F &fst, typename F::Arc::StateId s) {
588   return fst.F::NumArcs(s);
589 }
590
591 template <class F>
592 inline ssize_t NumInputEpsilons(const F &fst, typename F::Arc::StateId s) {
593   return fst.F::NumInputEpsilons(s);
594 }
595
596 template <class F>
597 inline ssize_t NumOutputEpsilons(const F &fst, typename F::Arc::StateId s) {
598   return fst.F::NumOutputEpsilons(s);
599 }
600
601 // Fst<Arc> case, abstract methods.
602
603 template <class Arc>
604 inline typename Arc::Weight Final(const Fst<Arc> &fst,
605                                   typename Arc::StateId s) {
606   return fst.Final(s);
607 }
608
609 template <class Arc>
610 inline size_t NumArcs(const Fst<Arc> &fst, typename Arc::StateId s) {
611   return fst.NumArcs(s);
612 }
613
614 template <class Arc>
615 inline size_t NumInputEpsilons(const Fst<Arc> &fst, typename Arc::StateId s) {
616   return fst.NumInputEpsilons(s);
617 }
618
619 template <class Arc>
620 inline size_t NumOutputEpsilons(const Fst<Arc> &fst, typename Arc::StateId s) {
621   return fst.NumOutputEpsilons(s);
622 }
623
624 }  // namespace internal
625
626 // FST implementation base.
627 //
628 // This is the recommended FST implementation base class. It will handle
629 // reference counts, property bits, type information and symbols.
630
631 namespace internal {
632
633 template <class Arc>
634 class FstImpl {
635  public:
636   using StateId = typename Arc::StateId;
637   using Weight = typename Arc::Weight;
638
639   FstImpl() : properties_(0), type_("null") {}
640
641   FstImpl(const FstImpl<Arc> &impl)
642       : properties_(impl.properties_),
643         type_(impl.type_),
644         isymbols_(impl.isymbols_ ? impl.isymbols_->Copy() : nullptr),
645         osymbols_(impl.osymbols_ ? impl.osymbols_->Copy() : nullptr) {}
646
647   virtual ~FstImpl() {}
648
649   const string &Type() const { return type_; }
650
651   void SetType(const string &type) { type_ = type; }
652
653   virtual uint64 Properties() const { return properties_; }
654
655   virtual uint64 Properties(uint64 mask) const { return properties_ & mask; }
656
657   void SetProperties(uint64 props) {
658     properties_ &= kError;  // kError can't be cleared.
659     properties_ |= props;
660   }
661
662   void SetProperties(uint64 props, uint64 mask) {
663     properties_ &= ~mask | kError;  // kError can't be cleared.
664     properties_ |= props & mask;
665   }
666
667   // Allows (only) setting error bit on const FST implementations.
668   void SetProperties(uint64 props, uint64 mask) const {
669     if (mask != kError) {
670       FSTERROR() << "FstImpl::SetProperties() const: Can only set kError";
671     }
672     properties_ |= kError;
673   }
674
675   const SymbolTable *InputSymbols() const { return isymbols_.get(); }
676
677   const SymbolTable *OutputSymbols() const { return osymbols_.get(); }
678
679   SymbolTable *InputSymbols() { return isymbols_.get(); }
680
681   SymbolTable *OutputSymbols() { return osymbols_.get(); }
682
683   void SetInputSymbols(const SymbolTable *isyms) {
684     isymbols_.reset(isyms ? isyms->Copy() : nullptr);
685   }
686
687   void SetOutputSymbols(const SymbolTable *osyms) {
688     osymbols_.reset(osyms ? osyms->Copy() : nullptr);
689   }
690
691   // Reads header and symbols from input stream, initializes FST, and returns
692   // the header. If opts.header is non-null, skips reading and uses the option
693   // value instead. If opts.[io]symbols is non-null, reads in (if present), but
694   // uses the option value.
695   bool ReadHeader(std::istream &strm, const FstReadOptions &opts,
696                   int min_version, FstHeader *hdr);
697
698   // Writes header and symbols to output stream. If opts.header is false, skips
699   // writing header. If opts.[io]symbols is false, skips writing those symbols.
700   // This method is needed for implementations that implement Write methods.
701   void WriteHeader(std::ostream &strm, const FstWriteOptions &opts,
702                    int version, FstHeader *hdr) const {
703     if (opts.write_header) {
704       hdr->SetFstType(type_);
705       hdr->SetArcType(Arc::Type());
706       hdr->SetVersion(version);
707       hdr->SetProperties(properties_);
708       int32 file_flags = 0;
709       if (isymbols_ && opts.write_isymbols) {
710         file_flags |= FstHeader::HAS_ISYMBOLS;
711       }
712       if (osymbols_ && opts.write_osymbols) {
713         file_flags |= FstHeader::HAS_OSYMBOLS;
714       }
715       if (opts.align) file_flags |= FstHeader::IS_ALIGNED;
716       hdr->SetFlags(file_flags);
717       hdr->Write(strm, opts.source);
718     }
719     if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
720     if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
721   }
722
723   // Writes out header and symbols to output stream. If opts.header is false,
724   // skips writing header. If opts.[io]symbols is false, skips writing those
725   // symbols. `type` is the FST type being written. This method is used in the
726   // cross-type serialization methods Fst::WriteFst.
727   static void WriteFstHeader(const Fst<Arc> &fst, std::ostream &strm,
728                              const FstWriteOptions &opts, int version,
729                              const string &type, uint64 properties,
730                              FstHeader *hdr) {
731     if (opts.write_header) {
732       hdr->SetFstType(type);
733       hdr->SetArcType(Arc::Type());
734       hdr->SetVersion(version);
735       hdr->SetProperties(properties);
736       int32 file_flags = 0;
737       if (fst.InputSymbols() && opts.write_isymbols) {
738         file_flags |= FstHeader::HAS_ISYMBOLS;
739       }
740       if (fst.OutputSymbols() && opts.write_osymbols) {
741         file_flags |= FstHeader::HAS_OSYMBOLS;
742       }
743       if (opts.align) file_flags |= FstHeader::IS_ALIGNED;
744       hdr->SetFlags(file_flags);
745       hdr->Write(strm, opts.source);
746     }
747     if (fst.InputSymbols() && opts.write_isymbols) {
748       fst.InputSymbols()->Write(strm);
749     }
750     if (fst.OutputSymbols() && opts.write_osymbols) {
751       fst.OutputSymbols()->Write(strm);
752     }
753   }
754
755   // In serialization routines where the header cannot be written until after
756   // the machine has been serialized, this routine can be called to seek to the
757   // beginning of the file an rewrite the header with updated fields. It
758   // repositions the file pointer back at the end of the file. Returns true on
759   // success, false on failure.
760   static bool UpdateFstHeader(const Fst<Arc> &fst, std::ostream &strm,
761                               const FstWriteOptions &opts, int version,
762                               const string &type, uint64 properties,
763                               FstHeader *hdr, size_t header_offset) {
764     strm.seekp(header_offset);
765     if (!strm) {
766       LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
767       return false;
768     }
769     WriteFstHeader(fst, strm, opts, version, type, properties, hdr);
770     if (!strm) {
771       LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
772       return false;
773     }
774     strm.seekp(0, std::ios_base::end);
775     if (!strm) {
776       LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
777       return false;
778     }
779     return true;
780   }
781
782  protected:
783   mutable uint64 properties_;  // Property bits.
784
785  private:
786   string type_;  // Unique name of FST class.
787   std::unique_ptr<SymbolTable> isymbols_;
788   std::unique_ptr<SymbolTable> osymbols_;
789 };
790
791 template <class Arc>
792 bool FstImpl<Arc>::ReadHeader(std::istream &strm, const FstReadOptions &opts,
793                               int min_version, FstHeader *hdr) {
794   if (opts.header) {
795     *hdr = *opts.header;
796   } else if (!hdr->Read(strm, opts.source)) {
797     return false;
798   }
799   if (FLAGS_v >= 2) {
800     LOG(INFO) << "FstImpl::ReadHeader: source: " << opts.source
801               << ", fst_type: " << hdr->FstType()
802               << ", arc_type: " << Arc::Type()
803               << ", version: " << hdr->Version()
804               << ", flags: " << hdr->GetFlags();
805   }
806   if (hdr->FstType() != type_) {
807     LOG(ERROR) << "FstImpl::ReadHeader: FST not of type " << type_
808                << ": " << opts.source;
809     return false;
810   }
811   if (hdr->ArcType() != Arc::Type()) {
812     LOG(ERROR) << "FstImpl::ReadHeader: Arc not of type " << Arc::Type()
813                << ": " << opts.source;
814     return false;
815   }
816   if (hdr->Version() < min_version) {
817     LOG(ERROR) << "FstImpl::ReadHeader: Obsolete " << type_
818                << " FST version: " << opts.source;
819     return false;
820   }
821   properties_ = hdr->Properties();
822   if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS) {
823     isymbols_.reset(SymbolTable::Read(strm, opts.source));
824   }
825   // Deletes input symbol table.
826   if (!opts.read_isymbols) SetInputSymbols(nullptr);
827   if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS) {
828     osymbols_.reset(SymbolTable::Read(strm, opts.source));
829   }
830   // Deletes output symbol table.
831   if (!opts.read_osymbols) SetOutputSymbols(nullptr);
832   if (opts.isymbols) {
833     isymbols_.reset(opts.isymbols->Copy());
834   }
835   if (opts.osymbols) {
836     osymbols_.reset(opts.osymbols->Copy());
837   }
838   return true;
839 }
840
841 }  // namespace internal
842
843 template <class Arc>
844 uint64 TestProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known);
845
846 // This is a helper class template useful for attaching an FST interface to
847 // its implementation, handling reference counting.
848 template <class Impl, class FST = Fst<typename Impl::Arc>>
849 class ImplToFst : public FST {
850  public:
851   using Arc = typename Impl::Arc;
852   using StateId = typename Arc::StateId;
853   using Weight = typename Arc::Weight;
854   using FST::operator=;
855
856   StateId Start() const override { return impl_->Start(); }
857
858   Weight Final(StateId s) const override { return impl_->Final(s); }
859
860   size_t NumArcs(StateId s) const override { return impl_->NumArcs(s); }
861
862   size_t NumInputEpsilons(StateId s) const override {
863     return impl_->NumInputEpsilons(s);
864   }
865
866   size_t NumOutputEpsilons(StateId s) const override {
867     return impl_->NumOutputEpsilons(s);
868   }
869
870   uint64 Properties(uint64 mask, bool test) const override {
871     if (test) {
872       uint64 knownprops, testprops = TestProperties(*this, mask, &knownprops);
873       impl_->SetProperties(testprops, knownprops);
874       return testprops & mask;
875     } else {
876       return impl_->Properties(mask);
877     }
878   }
879
880   const string &Type() const override { return impl_->Type(); }
881
882   const SymbolTable *InputSymbols() const override {
883     return impl_->InputSymbols();
884   }
885
886   const SymbolTable *OutputSymbols() const override {
887     return impl_->OutputSymbols();
888   }
889
890  protected:
891   explicit ImplToFst(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}
892
893   // This constructor presumes there is a copy constructor for the
894   // implementation.
895   ImplToFst(const ImplToFst<Impl, FST> &fst, bool safe) {
896     if (safe) {
897       impl_ = std::make_shared<Impl>(*(fst.impl_));
898     } else {
899       impl_ = fst.impl_;
900     }
901   }
902
903   // Returns raw pointers to the shared object.
904   const Impl *GetImpl() const { return impl_.get(); }
905
906   Impl *GetMutableImpl() const { return impl_.get(); }
907
908   // Returns a ref-counted smart poiner to the implementation.
909   std::shared_ptr<Impl> GetSharedImpl() const { return impl_; }
910
911   bool Unique() const { return impl_.unique(); }
912
913   void SetImpl(std::shared_ptr<Impl> impl) { impl_ = impl; }
914
915  private:
916   template <class IFST, class OFST>
917   friend void Cast(const IFST &ifst, OFST *ofst);
918
919   std::shared_ptr<Impl> impl_;
920 };
921
922 // Converts FSTs by casting their implementations, where this makes sense
923 // (which excludes implementations with weight-dependent virtual methods).
924 // Must be a friend of the FST classes involved (currently the concrete FSTs:
925 // ConstFst, CompactFst, and VectorFst). This can only be safely used for arc
926 // types that have identical storage characteristics. As with an FST
927 // copy constructor and Copy() method, this is a constant time operation
928 // (but subject to copy-on-write if it is a MutableFst and modified).
929 template <class IFST, class OFST>
930 void Cast(const IFST &ifst, OFST *ofst) {
931   using OImpl = typename OFST::Impl;
932   ofst->impl_ = std::shared_ptr<OImpl>(ifst.impl_,
933       reinterpret_cast<OImpl *>(ifst.impl_.get()));
934 }
935
936 // FST serialization.
937
938 template <class Arc>
939 void FstToString(const Fst<Arc> &fst, string *result) {
940   std::ostringstream ostrm;
941   fst.Write(ostrm, FstWriteOptions("FstToString"));
942   *result = ostrm.str();
943 }
944
945 template <class Arc>
946 void FstToString(const Fst<Arc> &fst, string *result,
947                  const FstWriteOptions &options) {
948   std::ostringstream ostrm;
949   fst.Write(ostrm, options);
950   *result = ostrm.str();
951 }
952
953 template <class Arc>
954 Fst<Arc> *StringToFst(const string &s) {
955   std::istringstream istrm(s);
956   return Fst<Arc>::Read(istrm, FstReadOptions("StringToFst"));
957 }
958
959 }  // namespace fst
960
961 #endif  // FST_LIB_FST_H_