Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / extensions / special / phi-fst.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3
4 #ifndef FST_EXTENSIONS_SPECIAL_PHI_FST_H_
5 #define FST_EXTENSIONS_SPECIAL_PHI_FST_H_
6
7 #include <memory>
8 #include <string>
9
10 #include <fst/const-fst.h>
11 #include <fst/matcher-fst.h>
12 #include <fst/matcher.h>
13
14 DECLARE_int64(phi_fst_phi_label);
15 DECLARE_bool(phi_fst_phi_loop);
16 DECLARE_string(phi_fst_rewrite_mode);
17
18 namespace fst {
19 namespace internal {
20
21 template <class Label>
22 class PhiFstMatcherData {
23  public:
24   PhiFstMatcherData(
25       Label phi_label = FLAGS_phi_fst_phi_label,
26       bool phi_loop = FLAGS_phi_fst_phi_loop,
27       MatcherRewriteMode rewrite_mode = RewriteMode(FLAGS_phi_fst_rewrite_mode))
28       : phi_label_(phi_label),
29         phi_loop_(phi_loop),
30         rewrite_mode_(rewrite_mode) {}
31
32   PhiFstMatcherData(const PhiFstMatcherData &data)
33       : phi_label_(data.phi_label_),
34         phi_loop_(data.phi_loop_),
35         rewrite_mode_(data.rewrite_mode_) {}
36
37   static PhiFstMatcherData<Label> *Read(std::istream &istrm,
38                                         const FstReadOptions &read) {
39     auto *data = new PhiFstMatcherData<Label>();
40     ReadType(istrm, &data->phi_label_);
41     ReadType(istrm, &data->phi_loop_);
42     int32 rewrite_mode;
43     ReadType(istrm, &rewrite_mode);
44     data->rewrite_mode_ = static_cast<MatcherRewriteMode>(rewrite_mode);
45     return data;
46   }
47
48   bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const {
49     WriteType(ostrm, phi_label_);
50     WriteType(ostrm, phi_loop_);
51     WriteType(ostrm, static_cast<int32>(rewrite_mode_));
52     return !ostrm ? false : true;
53   }
54
55   Label PhiLabel() const { return phi_label_; }
56
57   bool PhiLoop() const { return phi_loop_; }
58
59   MatcherRewriteMode RewriteMode() const { return rewrite_mode_; }
60
61  private:
62   static MatcherRewriteMode RewriteMode(const string &mode) {
63     if (mode == "auto") return MATCHER_REWRITE_AUTO;
64     if (mode == "always") return MATCHER_REWRITE_ALWAYS;
65     if (mode == "never") return MATCHER_REWRITE_NEVER;
66     LOG(WARNING) << "PhiFst: Unknown rewrite mode: " << mode << ". "
67                  << "Defaulting to auto.";
68     return MATCHER_REWRITE_AUTO;
69   }
70
71   Label phi_label_;
72   bool phi_loop_;
73   MatcherRewriteMode rewrite_mode_;
74 };
75
76 }  // namespace internal
77
78 constexpr uint8 kPhiFstMatchInput = 0x01;   // Input matcher is PhiMatcher.
79 constexpr uint8 kPhiFstMatchOutput = 0x02;  // Output matcher is PhiMatcher.
80
81 template <class M, uint8 flags = kPhiFstMatchInput | kPhiFstMatchOutput>
82 class PhiFstMatcher : public PhiMatcher<M> {
83  public:
84   using FST = typename M::FST;
85   using Arc = typename M::Arc;
86   using StateId = typename Arc::StateId;
87   using Label = typename Arc::Label;
88   using Weight = typename Arc::Weight;
89   using MatcherData = internal::PhiFstMatcherData<Label>;
90
91   enum : uint8 { kFlags = flags };
92
93   // This makes a copy of the FST.
94   PhiFstMatcher(const FST &fst, MatchType match_type,
95       std::shared_ptr<MatcherData> data = std::make_shared<MatcherData>())
96       : PhiMatcher<M>(fst, match_type,
97                       PhiLabel(match_type, data ? data->PhiLabel()
98                                                 : MatcherData().PhiLabel()),
99                       data ? data->PhiLoop() : MatcherData().PhiLoop(),
100                       data ? data->RewriteMode() : MatcherData().RewriteMode()),
101         data_(data) {}
102
103   // This doesn't copy the FST.
104   PhiFstMatcher(const FST *fst, MatchType match_type,
105       std::shared_ptr<MatcherData> data = std::make_shared<MatcherData>())
106       : PhiMatcher<M>(fst, match_type,
107                       PhiLabel(match_type, data ? data->PhiLabel()
108                                                 : MatcherData().PhiLabel()),
109                       data ? data->PhiLoop() : MatcherData().PhiLoop(),
110                       data ? data->RewriteMode() : MatcherData().RewriteMode()),
111         data_(data) {}
112
113   // This makes a copy of the FST.
114   PhiFstMatcher(const PhiFstMatcher<M, flags> &matcher, bool safe = false)
115       : PhiMatcher<M>(matcher, safe), data_(matcher.data_) {}
116
117   PhiFstMatcher<M, flags> *Copy(bool safe = false) const override {
118     return new PhiFstMatcher<M, flags>(*this, safe);
119   }
120
121   const MatcherData *GetData() const { return data_.get(); }
122
123   std::shared_ptr<MatcherData> GetSharedData() const { return data_; }
124
125  private:
126   static Label PhiLabel(MatchType match_type, Label label) {
127     if (match_type == MATCH_INPUT && flags & kPhiFstMatchInput) return label;
128     if (match_type == MATCH_OUTPUT && flags & kPhiFstMatchOutput) return label;
129     return kNoLabel;
130   }
131
132   std::shared_ptr<MatcherData> data_;
133 };
134
135 extern const char phi_fst_type[];
136 extern const char input_phi_fst_type[];
137 extern const char output_phi_fst_type[];
138
139 using StdPhiFst =
140     MatcherFst<ConstFst<StdArc>, PhiFstMatcher<SortedMatcher<ConstFst<StdArc>>>,
141                phi_fst_type>;
142
143 using LogPhiFst =
144     MatcherFst<ConstFst<LogArc>, PhiFstMatcher<SortedMatcher<ConstFst<LogArc>>>,
145                phi_fst_type>;
146
147 using Log64PhiFst = MatcherFst<ConstFst<Log64Arc>,
148                                PhiFstMatcher<SortedMatcher<ConstFst<Log64Arc>>>,
149                                input_phi_fst_type>;
150
151 using StdInputPhiFst =
152     MatcherFst<ConstFst<StdArc>, PhiFstMatcher<SortedMatcher<ConstFst<StdArc>>,
153                                                kPhiFstMatchInput>,
154                input_phi_fst_type>;
155
156 using LogInputPhiFst =
157     MatcherFst<ConstFst<LogArc>, PhiFstMatcher<SortedMatcher<ConstFst<LogArc>>,
158                                                kPhiFstMatchInput>,
159                input_phi_fst_type>;
160
161 using Log64InputPhiFst = MatcherFst<
162     ConstFst<Log64Arc>,
163     PhiFstMatcher<SortedMatcher<ConstFst<Log64Arc>>, kPhiFstMatchInput>,
164     input_phi_fst_type>;
165
166 using StdOutputPhiFst =
167     MatcherFst<ConstFst<StdArc>, PhiFstMatcher<SortedMatcher<ConstFst<StdArc>>,
168                                                kPhiFstMatchOutput>,
169                output_phi_fst_type>;
170
171 using LogOutputPhiFst =
172     MatcherFst<ConstFst<LogArc>, PhiFstMatcher<SortedMatcher<ConstFst<LogArc>>,
173                                                kPhiFstMatchOutput>,
174                output_phi_fst_type>;
175
176 using Log64OutputPhiFst = MatcherFst<
177     ConstFst<Log64Arc>,
178     PhiFstMatcher<SortedMatcher<ConstFst<Log64Arc>>, kPhiFstMatchOutput>,
179     output_phi_fst_type>;
180
181 }  // namespace fst
182
183 #endif  // FST_EXTENSIONS_SPECIAL_PHI_FST_H_