ced06de29513fa3b3fb6d81781344d58206218f7
[platform/upstream/openfst.git] / src / script / fst-class.cc
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // These classes are only recommended for use in high-level scripting
5 // applications. Most users should use the lower-level templated versions
6 // corresponding to these classes.
7
8 #include <istream>
9
10 #include <fst/log.h>
11
12 #include <fst/equal.h>
13 #include <fst/fst-decl.h>
14 #include <fst/reverse.h>
15 #include <fst/union.h>
16 #include <fst/script/fst-class.h>
17 #include <fst/script/register.h>
18
19 namespace fst {
20 namespace script {
21
22 // Registration.
23
24 REGISTER_FST_CLASSES(StdArc);
25 REGISTER_FST_CLASSES(LogArc);
26 REGISTER_FST_CLASSES(Log64Arc);
27
28 // FstClass methods.
29
30 template <class FstT>
31 FstT *ReadFst(std::istream &istrm, const string &fname) {
32   if (!istrm) {
33     LOG(ERROR) << "ReadFst: Can't open file: " << fname;
34     return nullptr;
35   }
36   FstHeader hdr;
37   if (!hdr.Read(istrm, fname)) return nullptr;
38   FstReadOptions read_options(fname, &hdr);
39   const auto arc_type = hdr.ArcType();
40   const auto reader =
41       IORegistration<FstT>::Register::GetRegister()->GetReader(arc_type);
42   if (!reader) {
43     LOG(ERROR) << "ReadFst: Unknown arc type: " << arc_type;
44     return nullptr;
45   }
46   return reader(istrm, read_options);
47 }
48
49 FstClass *FstClass::Read(const string &fname) {
50   if (!fname.empty()) {
51     std::ifstream istrm(fname, std::ios_base::in | std::ios_base::binary);
52     return ReadFst<FstClass>(istrm, fname);
53   } else {
54     return ReadFst<FstClass>(std::cin, "standard input");
55   }
56 }
57
58 FstClass *FstClass::Read(std::istream &istrm, const string &source) {
59   return ReadFst<FstClass>(istrm, source);
60 }
61
62 FstClass *FstClass::ReadFromString(const string &fst_string) {
63   std::istringstream istrm(fst_string);
64   return ReadFst<FstClass>(istrm, "StringToFst");
65 }
66
67 const string FstClass::WriteToString() const {
68   std::ostringstream ostrm;
69   Write(ostrm, FstWriteOptions("StringToFst"));
70   return ostrm.str();
71 }
72
73 bool FstClass::WeightTypesMatch(const WeightClass &weight,
74                                 const string &op_name) const {
75   if (WeightType() != weight.Type()) {
76     FSTERROR() << "FST and weight with non-matching weight types passed to "
77                << op_name << ": " << WeightType() << " and " << weight.Type();
78     return false;
79   }
80   return true;
81 }
82
83 // MutableFstClass methods.
84
85 MutableFstClass *MutableFstClass::Read(const string &fname, bool convert) {
86   if (convert == false) {
87     if (!fname.empty()) {
88       std::ifstream in(fname, std::ios_base::in | std::ios_base::binary);
89       return ReadFst<MutableFstClass>(in, fname);
90     } else {
91       return ReadFst<MutableFstClass>(std::cin, "standard input");
92     }
93   } else {  // Converts to VectorFstClass if not mutable.
94     FstClass *ifst = FstClass::Read(fname);
95     if (!ifst) return nullptr;
96     if (ifst->Properties(fst::kMutable, false)) {
97       return static_cast<MutableFstClass *>(ifst);
98     } else {
99       MutableFstClass *ofst = new VectorFstClass(*ifst);
100       delete ifst;
101       return ofst;
102     }
103   }
104 }
105
106 // VectorFstClass methods.
107
108 VectorFstClass *VectorFstClass::Read(const string &fname) {
109   if (!fname.empty()) {
110     std::ifstream in(fname, std::ios_base::in | std::ios_base::binary);
111     return ReadFst<VectorFstClass>(in, fname);
112   } else {
113     return ReadFst<VectorFstClass>(std::cin, "standard input");
114   }
115 }
116
117 IORegistration<VectorFstClass>::Entry GetVFSTRegisterEntry(
118     const string &arc_type) {
119   return IORegistration<VectorFstClass>::Register::GetRegister()->GetEntry(
120       arc_type);
121 }
122
123 VectorFstClass::VectorFstClass(const string &arc_type)
124     : MutableFstClass(GetVFSTRegisterEntry(arc_type).creator()) {
125   if (Properties(kError, true) == kError) {
126     FSTERROR() << "VectorFstClass: Unknown arc type: " << arc_type;
127   }
128 }
129
130 VectorFstClass::VectorFstClass(const FstClass &other)
131     : MutableFstClass(GetVFSTRegisterEntry(other.ArcType()).converter(other)) {}
132
133 }  // namespace script
134 }  // namespace fst