1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 #ifndef FST_EXTENSIONS_SPECIAL_RHO_FST_H_
5 #define FST_EXTENSIONS_SPECIAL_RHO_FST_H_
10 #include <fst/const-fst.h>
11 #include <fst/matcher-fst.h>
12 #include <fst/matcher.h>
// Flag declarations. FLAGS_rho_fst_rho_label and FLAGS_rho_fst_rewrite_mode
// supply the default arguments of the RhoFstMatcherData constructor.
14 DECLARE_int64(rho_fst_rho_label);
15 DECLARE_string(rho_fst_rewrite_mode);
// Configuration record for a rho matcher: stores a rho label and a
// MatcherRewriteMode. Instances can be serialized with Write() and
// reconstructed from a stream with the static Read().
20 template <class Label>
21 class RhoFstMatcherData {
// Constructs the data from a rho label and a rewrite mode.
// When an argument is omitted, the label defaults to FLAGS_rho_fst_rho_label
// and the rewrite mode is obtained by parsing the string flag
// FLAGS_rho_fst_rewrite_mode with the static RewriteMode(const string &).
23 explicit RhoFstMatcherData(
24 Label rho_label = FLAGS_rho_fst_rho_label,
25 MatcherRewriteMode rewrite_mode = RewriteMode(FLAGS_rho_fst_rewrite_mode))
26 : rho_label_(rho_label), rewrite_mode_(rewrite_mode) {}
// Copy constructor: copies the rho label and the rewrite mode from `data`.
28 RhoFstMatcherData(const RhoFstMatcherData &data)
29 : rho_label_(data.rho_label_), rewrite_mode_(data.rewrite_mode_) {}
// Deserializes matcher data from `istrm`.
// A new RhoFstMatcherData is allocated with `new` using the default
// constructor arguments; the rho label and then the rewrite mode are read
// from the stream, i.e. in the same order in which Write() emits them.
// The rewrite mode is read into a local `rewrite_mode` and converted to
// MatcherRewriteMode with a static_cast.
// NOTE(review): `rewrite_mode` is presumably an int32 so that it matches what
// Write() emits, and ownership of the allocated object presumably passes to
// the caller — confirm both.
31 static RhoFstMatcherData<Label> *Read(std::istream &istrm,
32 const FstReadOptions &read) {
33 auto *data = new RhoFstMatcherData<Label>();
34 ReadType(istrm, &data->rho_label_);
36 ReadType(istrm, &rewrite_mode);
37 data->rewrite_mode_ = static_cast<MatcherRewriteMode>(rewrite_mode);
// Serializes this object to `ostrm`: first the rho label, then the rewrite
// mode converted to int32.
// Returns false if the stream is in a failed state after the writes and true
// otherwise. `opts` is not consulted by these statements.
41 bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const {
42 WriteType(ostrm, rho_label_);
43 WriteType(ostrm, static_cast<int32>(rewrite_mode_));
44 return !ostrm ? false : true;
// Returns the stored rho label.
47 Label RhoLabel() const { return rho_label_; }
// Returns the stored rewrite mode.
49 MatcherRewriteMode RewriteMode() const { return rewrite_mode_; }
// Converts a textual rewrite mode into a MatcherRewriteMode. The comparison
// is an exact string match:
//   "auto"   -> MATCHER_REWRITE_AUTO
//   "always" -> MATCHER_REWRITE_ALWAYS
//   "never"  -> MATCHER_REWRITE_NEVER
// Any other value logs a warning and falls back to MATCHER_REWRITE_AUTO.
52 static MatcherRewriteMode RewriteMode(const string &mode) {
53 if (mode == "auto") return MATCHER_REWRITE_AUTO;
54 if (mode == "always") return MATCHER_REWRITE_ALWAYS;
55 if (mode == "never") return MATCHER_REWRITE_NEVER;
56 LOG(WARNING) << "RhoFst: Unknown rewrite mode: " << mode << ". "
57 << "Defaulting to auto.";
58 return MATCHER_REWRITE_AUTO;
// Rewrite mode; returned by RewriteMode() and serialized by Write().
62 MatcherRewriteMode rewrite_mode_;
65 } // namespace internal
// Bit flags for the `flags` template argument of RhoFstMatcher. They select
// for which match types (input and/or output) RhoFstMatcher::RhoLabel() hands
// the configured rho label through.
67 constexpr uint8 kRhoFstMatchInput = 0x01; // Input matcher is RhoMatcher.
68 constexpr uint8 kRhoFstMatchOutput = 0x02; // Output matcher is RhoMatcher.
// Matcher derived from RhoMatcher<M> whose rho label and rewrite mode are
// taken from a shared MatcherData object.
// `flags` is a bitmask of kRhoFstMatchInput / kRhoFstMatchOutput; by default
// both bits are set.
70 template <class M, uint8 flags = kRhoFstMatchInput | kRhoFstMatchOutput>
71 class RhoFstMatcher : public RhoMatcher<M> {
// Type aliases derived from the underlying matcher M and its arc type.
73 using FST = typename M::FST;
74 using Arc = typename M::Arc;
75 using StateId = typename Arc::StateId;
76 using Label = typename Arc::Label;
77 using Weight = typename Arc::Weight;
// Configuration record type (rho label + rewrite mode) for this matcher.
78 using MatcherData = internal::RhoFstMatcherData<Label>;
// Exposes the `flags` template argument as a class-scope constant.
80 enum : uint8 { kFlags = flags };
82 // This makes a copy of the FST.
// Parameters: the FST (by reference), the match type, and shared matcher
// data, which defaults to a freshly constructed MatcherData.
// The RhoMatcher<M> base receives the label selected by
// RhoLabel(match_type, ...) and the rewrite mode. If `data` is null, both
// values come from a default-constructed MatcherData instead.
84 const FST &fst, MatchType match_type,
85 std::shared_ptr<MatcherData> data = std::make_shared<MatcherData>())
86 : RhoMatcher<M>(fst, match_type,
87 RhoLabel(match_type, data ? data->RhoLabel()
88 : MatcherData().RhoLabel()),
89 data ? data->RewriteMode() : MatcherData().RewriteMode()),
92 // This doesn't copy the FST.
// Parameters: the FST (by pointer), the match type, and shared matcher data,
// which defaults to a freshly constructed MatcherData.
// The RhoMatcher<M> base receives the label selected by
// RhoLabel(match_type, ...) and the rewrite mode. If `data` is null, both
// values come from a default-constructed MatcherData instead.
94 const FST *fst, MatchType match_type,
95 std::shared_ptr<MatcherData> data = std::make_shared<MatcherData>())
96 : RhoMatcher<M>(fst, match_type,
97 RhoLabel(match_type, data ? data->RhoLabel()
98 : MatcherData().RhoLabel()),
99 data ? data->RewriteMode() : MatcherData().RewriteMode()),
102 // This makes a copy of the FST.
// Copy constructor: forwards `matcher` and `safe` to the RhoMatcher<M> copy
// constructor and copies the shared_ptr, so the new matcher shares the same
// MatcherData object rather than duplicating it.
103 RhoFstMatcher(const RhoFstMatcher<M, flags> &matcher, bool safe = false)
104 : RhoMatcher<M>(matcher, safe), data_(matcher.data_) {}
// Returns a copy of this matcher allocated with `new`, built with the copy
// constructor and the given `safe` argument. Overrides the base-class Copy().
106 RhoFstMatcher<M, flags> *Copy(bool safe = false) const override {
107 return new RhoFstMatcher<M, flags>(*this, safe);
// Returns a non-owning pointer to the matcher data held in data_.
110 const MatcherData *GetData() const { return data_.get(); }
// Returns a copy of the shared_ptr held in data_, i.e. shares ownership of
// the matcher data with the caller.
112 std::shared_ptr<MatcherData> GetSharedData() const { return data_; }
// Selects the label handed to the RhoMatcher base for a given match type.
// Returns `label` when match_type is MATCH_INPUT and kRhoFstMatchInput is set
// in `flags`, or when match_type is MATCH_OUTPUT and kRhoFstMatchOutput is set.
// NOTE(review): the value returned for every other combination should be
// confirmed (presumably kNoLabel).
115 static Label RhoLabel(MatchType match_type, Label label) {
116 if (match_type == MATCH_INPUT && flags & kRhoFstMatchInput) return label;
117 if (match_type == MATCH_OUTPUT && flags & kRhoFstMatchOutput) return label;
// Matcher data shared between this matcher and its copies; exposed through
// GetData() and GetSharedData().
121 std::shared_ptr<MatcherData> data_;
// Character arrays declared here and defined elsewhere. output_rho_fst_type
// is passed as a template argument to the *OutputRhoFst aliases;
// rho_fst_type and input_rho_fst_type presumably play the same role for the
// plain and input-side aliases — confirm.
124 extern const char rho_fst_type[];
125 extern const char input_rho_fst_type[];
126 extern const char output_rho_fst_type[];
// MatcherFst instantiations over ConstFst for StdArc, LogArc and Log64Arc.
// Each uses a SortedMatcher wrapped in RhoFstMatcher with the default
// `flags` argument (both input and output bits set).
129 MatcherFst<ConstFst<StdArc>, RhoFstMatcher<SortedMatcher<ConstFst<StdArc>>>,
133 MatcherFst<ConstFst<LogArc>, RhoFstMatcher<SortedMatcher<ConstFst<LogArc>>>,
136 using Log64RhoFst = MatcherFst<ConstFst<Log64Arc>,
137 RhoFstMatcher<SortedMatcher<ConstFst<Log64Arc>>>,
// Input-side variants: MatcherFst over ConstFst whose RhoFstMatcher takes an
// explicit `flags` argument. Log64InputRhoFst passes kRhoFstMatchInput; the
// Std and Log variants presumably do the same — confirm.
140 using StdInputRhoFst =
141 MatcherFst<ConstFst<StdArc>, RhoFstMatcher<SortedMatcher<ConstFst<StdArc>>,
145 using LogInputRhoFst =
146 MatcherFst<ConstFst<LogArc>, RhoFstMatcher<SortedMatcher<ConstFst<LogArc>>,
150 using Log64InputRhoFst = MatcherFst<
152 RhoFstMatcher<SortedMatcher<ConstFst<Log64Arc>>, kRhoFstMatchInput>,
// Output-side variants: MatcherFst over ConstFst whose RhoFstMatcher takes an
// explicit `flags` argument, with output_rho_fst_type as the final template
// argument. Log64OutputRhoFst passes kRhoFstMatchOutput; the Std and Log
// variants presumably do the same — confirm.
155 using StdOutputRhoFst =
156 MatcherFst<ConstFst<StdArc>, RhoFstMatcher<SortedMatcher<ConstFst<StdArc>>,
158 output_rho_fst_type>;
160 using LogOutputRhoFst =
161 MatcherFst<ConstFst<LogArc>, RhoFstMatcher<SortedMatcher<ConstFst<LogArc>>,
163 output_rho_fst_type>;
165 using Log64OutputRhoFst = MatcherFst<
167 RhoFstMatcher<SortedMatcher<ConstFst<Log64Arc>>, kRhoFstMatchOutput>,
168 output_rho_fst_type>;
172 #endif // FST_EXTENSIONS_SPECIAL_RHO_FST_H_