Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / extensions / special / rho-fst.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3
4 #ifndef FST_EXTENSIONS_SPECIAL_RHO_FST_H_
5 #define FST_EXTENSIONS_SPECIAL_RHO_FST_H_
6
7 #include <memory>
8 #include <string>
9
10 #include <fst/const-fst.h>
11 #include <fst/matcher-fst.h>
12 #include <fst/matcher.h>
13
14 DECLARE_int64(rho_fst_rho_label);
15 DECLARE_string(rho_fst_rewrite_mode);
16
17 namespace fst {
18 namespace internal {
19
20 template <class Label>
21 class RhoFstMatcherData {
22  public:
23   explicit RhoFstMatcherData(
24       Label rho_label = FLAGS_rho_fst_rho_label,
25       MatcherRewriteMode rewrite_mode = RewriteMode(FLAGS_rho_fst_rewrite_mode))
26       : rho_label_(rho_label), rewrite_mode_(rewrite_mode) {}
27
28   RhoFstMatcherData(const RhoFstMatcherData &data)
29       : rho_label_(data.rho_label_), rewrite_mode_(data.rewrite_mode_) {}
30
31   static RhoFstMatcherData<Label> *Read(std::istream &istrm,
32                                     const FstReadOptions &read) {
33     auto *data = new RhoFstMatcherData<Label>();
34     ReadType(istrm, &data->rho_label_);
35     int32 rewrite_mode;
36     ReadType(istrm, &rewrite_mode);
37     data->rewrite_mode_ = static_cast<MatcherRewriteMode>(rewrite_mode);
38     return data;
39   }
40
41   bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const {
42     WriteType(ostrm, rho_label_);
43     WriteType(ostrm, static_cast<int32>(rewrite_mode_));
44     return !ostrm ? false : true;
45   }
46
47   Label RhoLabel() const { return rho_label_; }
48
49   MatcherRewriteMode RewriteMode() const { return rewrite_mode_; }
50
51  private:
52   static MatcherRewriteMode RewriteMode(const string &mode) {
53     if (mode == "auto") return MATCHER_REWRITE_AUTO;
54     if (mode == "always") return MATCHER_REWRITE_ALWAYS;
55     if (mode == "never") return MATCHER_REWRITE_NEVER;
56     LOG(WARNING) << "RhoFst: Unknown rewrite mode: " << mode << ". "
57                  << "Defaulting to auto.";
58     return MATCHER_REWRITE_AUTO;
59   }
60
61   Label rho_label_;
62   MatcherRewriteMode rewrite_mode_;
63 };
64
65 }  // namespace internal
66
67 constexpr uint8 kRhoFstMatchInput = 0x01;   // Input matcher is RhoMatcher.
68 constexpr uint8 kRhoFstMatchOutput = 0x02;  // Output matcher is RhoMatcher.
69
70 template <class M, uint8 flags = kRhoFstMatchInput | kRhoFstMatchOutput>
71 class RhoFstMatcher : public RhoMatcher<M> {
72  public:
73   using FST = typename M::FST;
74   using Arc = typename M::Arc;
75   using StateId = typename Arc::StateId;
76   using Label = typename Arc::Label;
77   using Weight = typename Arc::Weight;
78   using MatcherData = internal::RhoFstMatcherData<Label>;
79
80   enum : uint8 { kFlags = flags };
81
82   // This makes a copy of the FST.
83   RhoFstMatcher(
84       const FST &fst, MatchType match_type,
85       std::shared_ptr<MatcherData> data = std::make_shared<MatcherData>())
86       : RhoMatcher<M>(fst, match_type,
87                       RhoLabel(match_type, data ? data->RhoLabel()
88                                                 : MatcherData().RhoLabel()),
89                       data ? data->RewriteMode() : MatcherData().RewriteMode()),
90         data_(data) {}
91
92   // This doesn't copy the FST.
93   RhoFstMatcher(
94       const FST *fst, MatchType match_type,
95       std::shared_ptr<MatcherData> data = std::make_shared<MatcherData>())
96       : RhoMatcher<M>(fst, match_type,
97                       RhoLabel(match_type, data ? data->RhoLabel()
98                                                 : MatcherData().RhoLabel()),
99                       data ? data->RewriteMode() : MatcherData().RewriteMode()),
100         data_(data) {}
101
102   // This makes a copy of the FST.
103   RhoFstMatcher(const RhoFstMatcher<M, flags> &matcher, bool safe = false)
104       : RhoMatcher<M>(matcher, safe), data_(matcher.data_) {}
105
106   RhoFstMatcher<M, flags> *Copy(bool safe = false) const override {
107     return new RhoFstMatcher<M, flags>(*this, safe);
108   }
109
110   const MatcherData *GetData() const { return data_.get(); }
111
112   std::shared_ptr<MatcherData> GetSharedData() const { return data_; }
113
114  private:
115   static Label RhoLabel(MatchType match_type, Label label) {
116     if (match_type == MATCH_INPUT && flags & kRhoFstMatchInput) return label;
117     if (match_type == MATCH_OUTPUT && flags & kRhoFstMatchOutput) return label;
118     return kNoLabel;
119   }
120
121   std::shared_ptr<MatcherData> data_;
122 };
123
124 extern const char rho_fst_type[];
125 extern const char input_rho_fst_type[];
126 extern const char output_rho_fst_type[];
127
128 using StdRhoFst =
129     MatcherFst<ConstFst<StdArc>, RhoFstMatcher<SortedMatcher<ConstFst<StdArc>>>,
130                rho_fst_type>;
131
132 using LogRhoFst =
133     MatcherFst<ConstFst<LogArc>, RhoFstMatcher<SortedMatcher<ConstFst<LogArc>>>,
134                rho_fst_type>;
135
136 using Log64RhoFst = MatcherFst<ConstFst<Log64Arc>,
137                                RhoFstMatcher<SortedMatcher<ConstFst<Log64Arc>>>,
138                                input_rho_fst_type>;
139
140 using StdInputRhoFst =
141     MatcherFst<ConstFst<StdArc>, RhoFstMatcher<SortedMatcher<ConstFst<StdArc>>,
142                                                kRhoFstMatchInput>,
143                input_rho_fst_type>;
144
145 using LogInputRhoFst =
146     MatcherFst<ConstFst<LogArc>, RhoFstMatcher<SortedMatcher<ConstFst<LogArc>>,
147                                                kRhoFstMatchInput>,
148                input_rho_fst_type>;
149
150 using Log64InputRhoFst = MatcherFst<
151     ConstFst<Log64Arc>,
152     RhoFstMatcher<SortedMatcher<ConstFst<Log64Arc>>, kRhoFstMatchInput>,
153     input_rho_fst_type>;
154
155 using StdOutputRhoFst =
156     MatcherFst<ConstFst<StdArc>, RhoFstMatcher<SortedMatcher<ConstFst<StdArc>>,
157                                                kRhoFstMatchOutput>,
158                output_rho_fst_type>;
159
160 using LogOutputRhoFst =
161     MatcherFst<ConstFst<LogArc>, RhoFstMatcher<SortedMatcher<ConstFst<LogArc>>,
162                                                kRhoFstMatchOutput>,
163                output_rho_fst_type>;
164
165 using Log64OutputRhoFst = MatcherFst<
166     ConstFst<Log64Arc>,
167     RhoFstMatcher<SortedMatcher<ConstFst<Log64Arc>>, kRhoFstMatchOutput>,
168     output_rho_fst_type>;
169
170 }  // namespace fst
171
172 #endif  // FST_EXTENSIONS_SPECIAL_RHO_FST_H_