7aad31eba6bc9fb4241d9f56c171b80ca0afd955
[platform/upstream/openfst.git] / src / include / fst / matcher-fst.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to add a matcher to an FST.
5
6 #ifndef FST_MATCHER_FST_H_
7 #define FST_MATCHER_FST_H_
8
9 #include <memory>
10 #include <string>
11
12 #include <fst/add-on.h>
13 #include <fst/const-fst.h>
14 #include <fst/lookahead-matcher.h>
15
16
17 namespace fst {
18
19 // Writeable matchers have the same interface as Matchers (as defined in
20 // matcher.h) along with the following additional methods:
21 //
22 // template <class F>
23 // class Matcher {
24 //  public:
25 //   using FST = F;
26 //   ...
27 //   using MatcherData = ...;   // Initialization data.
28 //
29 //   // Constructor with additional argument for external initialization data;
30 //   // matcher increments its reference count on construction and decrements
31 //   // the reference count, and deletes once the reference count has reached
32 //   // zero.
33 //   Matcher(const FST &fst, MatchType type, MatcherData *data);
34 //
35 //   // Returns pointer to initialization data that can be passed to a Matcher
36 //   // constructor.
37 //   MatcherData *GetData() const;
38 // };
39
40 // The matcher initialization data class must also provide the following
41 // interface:
42 //
43 // class MatcherData {
44 // public:
45 //   // Required copy constructor.
46 //   MatcherData(const MatcherData &);
47 //
48 //   // Required I/O methods.
49 //   static MatcherData *Read(std::istream &istrm, const FstReadOptions &opts);
50 //   bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const;
51 // };
52
53 // Trivial (no-op) MatcherFst initializer functor.
54 template <class M>
55 class NullMatcherFstInit {
56  public:
57   using MatcherData = typename M::MatcherData;
58   using Data = AddOnPair<MatcherData, MatcherData>;
59   using Impl = internal::AddOnImpl<typename M::FST, Data>;
60
61   explicit NullMatcherFstInit(std::shared_ptr<Impl> *) {}
62 };
63
64 // Class adding a matcher to an FST type. Creates a new FST whose name is given
65 // by N. An optional functor Init can be used to initialize the FST. The Data
66 // template parameter allows the user to select the type of the add-on.
67 template <
68     class F, class M, const char *Name, class Init = NullMatcherFstInit<M>,
69     class Data = AddOnPair<typename M::MatcherData, typename M::MatcherData>>
70 class MatcherFst : public ImplToExpandedFst<internal::AddOnImpl<F, Data>> {
71  public:
72   using FST = F;
73   using Arc = typename FST::Arc;
74   using StateId = typename Arc::StateId;
75
76   using FstMatcher = M;
77   using MatcherData = typename FstMatcher::MatcherData;
78
79   using Impl = internal::AddOnImpl<FST, Data>;
80   using D = Data;
81
82   friend class StateIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
83   friend class ArcIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
84
85   MatcherFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>(FST(), Name)) {}
86
87   explicit MatcherFst(const FST &fst, std::shared_ptr<Data> data = nullptr)
88       : ImplToExpandedFst<Impl>(data ? CreateImpl(fst, Name, data)
89                                      : CreateDataAndImpl(fst, Name)) {}
90
91   explicit MatcherFst(const Fst<Arc> &fst)
92       : ImplToExpandedFst<Impl>(CreateDataAndImpl(fst, Name)) {}
93
94   // See Fst<>::Copy() for doc.
95   MatcherFst(const MatcherFst<FST, FstMatcher, Name, Init, Data> &fst,
96              bool safe = false)
97       : ImplToExpandedFst<Impl>(fst, safe) {}
98
99   // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc.
100   MatcherFst<FST, FstMatcher, Name, Init, Data> *Copy(
101       bool safe = false) const override {
102     return new MatcherFst<FST, FstMatcher, Name, Init, Data>(*this, safe);
103   }
104
105   // Read a MatcherFst from an input stream; return nullptr on error
106   static MatcherFst<FST, M, Name, Init, Data> *Read(
107       std::istream &strm, const FstReadOptions &opts) {
108     auto *impl = Impl::Read(strm, opts);
109     return impl ? new MatcherFst<FST, FstMatcher, Name, Init, Data>(
110                       std::shared_ptr<Impl>(impl))
111                 : nullptr;
112   }
113
114   // Read a MatcherFst from a file; return nullptr on error
115   // Empty filename reads from standard input
116   static MatcherFst<FST, FstMatcher, Name, Init, Data> *Read(
117       const string &filename) {
118     auto *impl = ImplToExpandedFst<Impl>::Read(filename);
119     return impl ? new MatcherFst<FST, FstMatcher, Name, Init, Data>(
120                       std::shared_ptr<Impl>(impl))
121                 : nullptr;
122   }
123
124   bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
125     return GetImpl()->Write(strm, opts);
126   }
127
128   bool Write(const string &filename) const override {
129     return Fst<Arc>::WriteFile(filename);
130   }
131
132   void InitStateIterator(StateIteratorData<Arc> *data) const override {
133     return GetImpl()->InitStateIterator(data);
134   }
135
136   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
137     return GetImpl()->InitArcIterator(s, data);
138   }
139
140   FstMatcher *InitMatcher(MatchType match_type) const override {
141     return new FstMatcher(GetFst(), match_type, GetSharedData(match_type));
142   }
143
144   const FST &GetFst() const { return GetImpl()->GetFst(); }
145
146   const Data *GetAddOn() const { return GetImpl()->GetAddOn(); }
147
148   std::shared_ptr<Data> GetSharedAddOn() const {
149     return GetImpl()->GetSharedAddOn();
150   }
151
152   const MatcherData *GetData(MatchType match_type) const {
153     const auto *data = GetAddOn();
154     return match_type == MATCH_INPUT ? data->First() : data->Second();
155   }
156
157   std::shared_ptr<MatcherData> GetSharedData(MatchType match_type) const {
158     const auto *data = GetAddOn();
159     return match_type == MATCH_INPUT ? data->SharedFirst()
160                                      : data->SharedSecond();
161   }
162
163  protected:
164   using ImplToFst<Impl, ExpandedFst<Arc>>::GetImpl;
165
166   static std::shared_ptr<Impl> CreateDataAndImpl(const FST &fst,
167                                                  const string &name) {
168     FstMatcher imatcher(fst, MATCH_INPUT);
169     FstMatcher omatcher(fst, MATCH_OUTPUT);
170     return CreateImpl(fst, name,
171                       std::make_shared<Data>(imatcher.GetSharedData(),
172                                              omatcher.GetSharedData()));
173   }
174
175   static std::shared_ptr<Impl> CreateDataAndImpl(const Fst<Arc> &fst,
176                                                  const string &name) {
177     FST result(fst);
178     return CreateDataAndImpl(result, name);
179   }
180
181   static std::shared_ptr<Impl> CreateImpl(const FST &fst, const string &name,
182                                           std::shared_ptr<Data> data) {
183     auto impl = std::make_shared<Impl>(fst, name);
184     impl->SetAddOn(data);
185     Init init(&impl);
186     return impl;
187   }
188
189   explicit MatcherFst(std::shared_ptr<Impl> impl)
190       : ImplToExpandedFst<Impl>(impl) {}
191
192  private:
193   MatcherFst &operator=(const MatcherFst &) = delete;
194 };
195
196 // Specialization for MatcherFst.
197 template <class FST, class M, const char *Name, class Init>
198 class StateIterator<MatcherFst<FST, M, Name, Init>>
199     : public StateIterator<FST> {
200  public:
201   explicit StateIterator(const MatcherFst<FST, M, Name, Init> &fst)
202       : StateIterator<FST>(fst.GetImpl()->GetFst()) {}
203 };
204
205 // Specialization for MatcherFst.
206 template <class FST, class M, const char *Name, class Init>
207 class ArcIterator<MatcherFst<FST, M, Name, Init>> : public ArcIterator<FST> {
208  public:
209   using StateId = typename FST::Arc::StateId;
210
211   ArcIterator(const MatcherFst<FST, M, Name, Init> &fst,
212               typename FST::Arc::StateId s)
213       : ArcIterator<FST>(fst.GetImpl()->GetFst(), s) {}
214 };
215
216 // Specialization for MatcherFst.
217 template <class F, class M, const char *Name, class Init>
218 class Matcher<MatcherFst<F, M, Name, Init>> {
219  public:
220   using FST = MatcherFst<F, M, Name, Init>;
221   using Arc = typename F::Arc;
222   using Label = typename Arc::Label;
223   using StateId = typename Arc::StateId;
224
225   Matcher(const FST &fst, MatchType match_type)
226       : matcher_(fst.InitMatcher(match_type)) {}
227
228   Matcher(const Matcher<FST> &matcher) : matcher_(matcher.matcher_->Copy()) {}
229
230   Matcher<FST> *Copy() const { return new Matcher<FST>(*this); }
231
232   MatchType Type(bool test) const { return matcher_->Type(test); }
233
234   void SetState(StateId s) { matcher_->SetState(s); }
235
236   bool Find(Label label) { return matcher_->Find(label); }
237
238   bool Done() const { return matcher_->Done(); }
239
240   const Arc &Value() const { return matcher_->Value(); }
241
242   void Next() { matcher_->Next(); }
243
244   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
245
246   uint32 Flags() const { return matcher_->Flags(); }
247
248  private:
249   std::unique_ptr<M> matcher_;
250 };
251
252 // Specialization for MatcherFst.
253 template <class F, class M, const char *Name, class Init>
254 class LookAheadMatcher<MatcherFst<F, M, Name, Init>> {
255  public:
256   using FST = MatcherFst<F, M, Name, Init>;
257   using Arc = typename F::Arc;
258   using Label = typename Arc::Label;
259   using StateId = typename Arc::StateId;
260   using Weight = typename Arc::Weight;
261
262   LookAheadMatcher(const FST &fst, MatchType match_type)
263       : matcher_(fst.InitMatcher(match_type)) {}
264
265   LookAheadMatcher(const LookAheadMatcher<FST> &matcher, bool safe = false)
266       : matcher_(matcher.matcher_->Copy(safe)) {}
267
268   // General matcher methods.
269   LookAheadMatcher<FST> *Copy(bool safe = false) const {
270     return new LookAheadMatcher<FST>(*this, safe);
271   }
272
273   MatchType Type(bool test) const { return matcher_->Type(test); }
274
275   void SetState(StateId s) { matcher_->SetState(s); }
276
277   bool Find(Label label) { return matcher_->Find(label); }
278
279   bool Done() const { return matcher_->Done(); }
280
281   const Arc &Value() const { return matcher_->Value(); }
282
283   void Next() { matcher_->Next(); }
284
285   const FST &GetFst() const { return matcher_->GetFst(); }
286
287   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
288
289   uint32 Flags() const { return matcher_->Flags(); }
290
291   bool LookAheadLabel(Label label) const {
292     return matcher_->LookAheadLabel(label);
293   }
294
295   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
296     return matcher_->LookAheadFst(fst, s);
297   }
298
299   Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); }
300
301   bool LookAheadPrefix(Arc *arc) const {
302     return matcher_->LookAheadPrefix(arc);
303   }
304
305   void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
306     matcher_->InitLookAheadFst(fst, copy);
307   }
308
309  private:
310   std::unique_ptr<M> matcher_;
311 };
312
313 // Useful aliases when using StdArc.
314
315 extern const char arc_lookahead_fst_type[];
316
317 using StdArcLookAheadFst =
318     MatcherFst<ConstFst<StdArc>,
319                ArcLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>>,
320                arc_lookahead_fst_type>;
321
322 extern const char ilabel_lookahead_fst_type[];
323 extern const char olabel_lookahead_fst_type[];
324
325 constexpr auto ilabel_lookahead_flags =
326     kInputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
327     kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
328
329 constexpr auto olabel_lookahead_flags =
330     kOutputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
331     kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
332
333 using StdILabelLookAheadFst = MatcherFst<
334     ConstFst<StdArc>,
335     LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
336                           ilabel_lookahead_flags, FastLogAccumulator<StdArc>>,
337     ilabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
338
339 using StdOLabelLookAheadFst = MatcherFst<
340     ConstFst<StdArc>,
341     LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
342                           olabel_lookahead_flags, FastLogAccumulator<StdArc>>,
343     olabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
344
345 }  // namespace fst
346
347 #endif  // FST_MATCHER_FST_H_