1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Class to add a matcher to an FST.
6 #ifndef FST_LIB_MATCHER_FST_H_
7 #define FST_LIB_MATCHER_FST_H_
12 #include <fst/add-on.h>
13 #include <fst/const-fst.h>
14 #include <fst/lookahead-matcher.h>
19 // Writeable matchers have the same interface as Matchers (as defined in
20 // matcher.h) along with the following additional methods:
27 // using MatcherData = ...; // Initialization data.
//   // Constructor with additional argument for external initialization data;
//   // the matcher shares ownership of the data through the std::shared_ptr,
//   // which releases it once the last owner is gone.
//   Matcher(const FST &fst, MatchType type,
//           std::shared_ptr<MatcherData> data);
//
//   // Returns a shared pointer to the initialization data, which can be
//   // passed to a Matcher constructor.
//   std::shared_ptr<MatcherData> GetSharedData() const;
// The matcher initialization data class must also provide the following
// interface:
//
// class MatcherData {
//  public:
//   // Required copy constructor.
//   MatcherData(const MatcherData &);
//
//   // Required I/O methods.
//   static MatcherData *Read(std::istream &istrm, const FstReadOptions &opts);
//   bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const;
// };
53 // Trivial (no-op) MatcherFst initializer functor.
55 class NullMatcherFstInit {
57 using MatcherData = typename M::MatcherData;
58 using Data = AddOnPair<MatcherData, MatcherData>;
59 using Impl = internal::AddOnImpl<typename M::FST, Data>;
61 explicit NullMatcherFstInit(std::shared_ptr<Impl> *) {}
64 // Class adding a matcher to an FST type. Creates a new FST whose name is given
65 // by N. An optional functor Init can be used to initialize the FST. The Data
66 // template parameter allows the user to select the type of the add-on.
68 class F, class M, const char *Name, class Init = NullMatcherFstInit<M>,
69 class Data = AddOnPair<typename M::MatcherData, typename M::MatcherData>>
70 class MatcherFst : public ImplToExpandedFst<internal::AddOnImpl<F, Data>> {
73 using Arc = typename FST::Arc;
74 using StateId = typename Arc::StateId;
77 using MatcherData = typename FstMatcher::MatcherData;
79 using Impl = internal::AddOnImpl<FST, Data>;
82 friend class StateIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
83 friend class ArcIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
85 MatcherFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>(FST(), Name)) {}
87 explicit MatcherFst(const FST &fst, std::shared_ptr<Data> data = nullptr)
88 : ImplToExpandedFst<Impl>(data ? CreateImpl(fst, Name, data)
89 : CreateDataAndImpl(fst, Name)) {}
91 explicit MatcherFst(const Fst<Arc> &fst)
92 : ImplToExpandedFst<Impl>(CreateDataAndImpl(fst, Name)) {}
94 // See Fst<>::Copy() for doc.
95 MatcherFst(const MatcherFst<FST, FstMatcher, Name, Init, Data> &fst,
97 : ImplToExpandedFst<Impl>(fst, safe) {}
99 // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc.
100 MatcherFst<FST, FstMatcher, Name, Init, Data> *Copy(
101 bool safe = false) const override {
102 return new MatcherFst<FST, FstMatcher, Name, Init, Data>(*this, safe);
105 // Read a MatcherFst from an input stream; return nullptr on error
106 static MatcherFst<FST, M, Name, Init, Data> *Read(
107 std::istream &strm, const FstReadOptions &opts) {
108 auto *impl = Impl::Read(strm, opts);
109 return impl ? new MatcherFst<FST, FstMatcher, Name, Init, Data>(
110 std::shared_ptr<Impl>(impl))
114 // Read a MatcherFst from a file; return nullptr on error
115 // Empty filename reads from standard input
116 static MatcherFst<FST, FstMatcher, Name, Init, Data> *Read(
117 const string &filename) {
118 auto *impl = ImplToExpandedFst<Impl>::Read(filename);
119 return impl ? new MatcherFst<FST, FstMatcher, Name, Init, Data>(
120 std::shared_ptr<Impl>(impl))
124 bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
125 return GetImpl()->Write(strm, opts);
128 bool Write(const string &filename) const override {
129 return Fst<Arc>::WriteFile(filename);
132 void InitStateIterator(StateIteratorData<Arc> *data) const override {
133 return GetImpl()->InitStateIterator(data);
136 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
137 return GetImpl()->InitArcIterator(s, data);
140 FstMatcher *InitMatcher(MatchType match_type) const override {
141 return new FstMatcher(GetFst(), match_type, GetSharedData(match_type));
144 const FST &GetFst() const { return GetImpl()->GetFst(); }
146 const Data *GetAddOn() const { return GetImpl()->GetAddOn(); }
148 std::shared_ptr<Data> GetSharedAddOn() const {
149 return GetImpl()->GetSharedAddOn();
152 const MatcherData *GetData(MatchType match_type) const {
153 const auto *data = GetAddOn();
154 return match_type == MATCH_INPUT ? data->First() : data->Second();
157 std::shared_ptr<MatcherData> GetSharedData(MatchType match_type) const {
158 const auto *data = GetAddOn();
159 return match_type == MATCH_INPUT ? data->SharedFirst()
160 : data->SharedSecond();
164 using ImplToFst<Impl, ExpandedFst<Arc>>::GetImpl;
166 static std::shared_ptr<Impl> CreateDataAndImpl(const FST &fst,
167 const string &name) {
168 FstMatcher imatcher(fst, MATCH_INPUT);
169 FstMatcher omatcher(fst, MATCH_OUTPUT);
170 return CreateImpl(fst, name,
171 std::make_shared<Data>(imatcher.GetSharedData(),
172 omatcher.GetSharedData()));
175 static std::shared_ptr<Impl> CreateDataAndImpl(const Fst<Arc> &fst,
176 const string &name) {
178 return CreateDataAndImpl(result, name);
181 static std::shared_ptr<Impl> CreateImpl(const FST &fst, const string &name,
182 std::shared_ptr<Data> data) {
184 auto impl = std::make_shared<Impl>(fst, name);
185 impl->SetAddOn(data);
190 explicit MatcherFst(std::shared_ptr<Impl> impl)
191 : ImplToExpandedFst<Impl>(impl) {}
194 MatcherFst &operator=(const MatcherFst &) = delete;
197 // Specialization for MatcherFst.
198 template <class FST, class M, const char *Name, class Init>
199 class StateIterator<MatcherFst<FST, M, Name, Init>>
200 : public StateIterator<FST> {
202 explicit StateIterator(const MatcherFst<FST, M, Name, Init> &fst)
203 : StateIterator<FST>(fst.GetImpl()->GetFst()) {}
206 // Specialization for MatcherFst.
207 template <class FST, class M, const char *Name, class Init>
208 class ArcIterator<MatcherFst<FST, M, Name, Init>> : public ArcIterator<FST> {
210 using StateId = typename FST::Arc::StateId;
212 ArcIterator(const MatcherFst<FST, M, Name, Init> &fst,
213 typename FST::Arc::StateId s)
214 : ArcIterator<FST>(fst.GetImpl()->GetFst(), s) {}
217 // Specialization for MatcherFst.
218 template <class F, class M, const char *Name, class Init>
219 class Matcher<MatcherFst<F, M, Name, Init>> {
221 using FST = MatcherFst<F, M, Name, Init>;
222 using Arc = typename F::Arc;
223 using Label = typename Arc::Label;
224 using StateId = typename Arc::StateId;
226 Matcher(const FST &fst, MatchType match_type)
227 : matcher_(fst.InitMatcher(match_type)) {}
229 Matcher(const Matcher<FST> &matcher) : matcher_(matcher.matcher_->Copy()) {}
231 Matcher<FST> *Copy() const { return new Matcher<FST>(*this); }
233 MatchType Type(bool test) const { return matcher_->Type(test); }
235 void SetState(StateId s) { matcher_->SetState(s); }
237 bool Find(Label label) { return matcher_->Find(label); }
239 bool Done() const { return matcher_->Done(); }
241 const Arc &Value() const { return matcher_->Value(); }
243 void Next() { matcher_->Next(); }
245 uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
247 uint32 Flags() const { return matcher_->Flags(); }
250 std::unique_ptr<M> matcher_;
253 // Specialization for MatcherFst.
254 template <class F, class M, const char *Name, class Init>
255 class LookAheadMatcher<MatcherFst<F, M, Name, Init>> {
257 using FST = MatcherFst<F, M, Name, Init>;
258 using Arc = typename F::Arc;
259 using Label = typename Arc::Label;
260 using StateId = typename Arc::StateId;
261 using Weight = typename Arc::Weight;
263 LookAheadMatcher(const FST &fst, MatchType match_type)
264 : matcher_(fst.InitMatcher(match_type)) {}
266 LookAheadMatcher(const LookAheadMatcher<FST> &matcher, bool safe = false)
267 : matcher_(matcher.matcher_->Copy(safe)) {}
269 // General matcher methods.
270 LookAheadMatcher<FST> *Copy(bool safe = false) const {
271 return new LookAheadMatcher<FST>(*this, safe);
274 MatchType Type(bool test) const { return matcher_->Type(test); }
276 void SetState(StateId s) { matcher_->SetState(s); }
278 bool Find(Label label) { return matcher_->Find(label); }
280 bool Done() const { return matcher_->Done(); }
282 const Arc &Value() const { return matcher_->Value(); }
284 void Next() { matcher_->Next(); }
286 const FST &GetFst() const { return matcher_->GetFst(); }
288 uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
290 uint32 Flags() const { return matcher_->Flags(); }
292 bool LookAheadLabel(Label label) const {
293 return matcher_->LookAheadLabel(label);
296 bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
297 return matcher_->LookAheadFst(fst, s);
300 Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); }
302 bool LookAheadPrefix(Arc *arc) const {
303 return matcher_->LookAheadPrefix(arc);
306 void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
307 matcher_->InitLookAheadFst(fst, copy);
311 std::unique_ptr<M> matcher_;
314 // Useful aliases when using StdArc.
316 extern const char arc_lookahead_fst_type[];
318 using StdArcLookAheadFst =
319 MatcherFst<ConstFst<StdArc>,
320 ArcLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>>,
321 arc_lookahead_fst_type>;
323 extern const char ilabel_lookahead_fst_type[];
324 extern const char olabel_lookahead_fst_type[];
326 constexpr auto ilabel_lookahead_flags =
327 kInputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
328 kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
330 constexpr auto olabel_lookahead_flags =
331 kOutputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
332 kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
334 using StdILabelLookAheadFst = MatcherFst<
336 LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
337 ilabel_lookahead_flags, FastLogAccumulator<StdArc>>,
338 ilabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
340 using StdOLabelLookAheadFst = MatcherFst<
342 LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
343 olabel_lookahead_flags, FastLogAccumulator<StdArc>>,
344 olabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
348 #endif // FST_LIB_MATCHER_FST_H_