167eb956f0927ecc6418fdc3640e090ea1911e9b
[platform/upstream/openfst.git] / src / include / fst / matcher-fst.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to add a matcher to an FST.
5
6 #ifndef FST_LIB_MATCHER_FST_H_
7 #define FST_LIB_MATCHER_FST_H_
8
9 #include <memory>
10 #include <string>
11
12 #include <fst/add-on.h>
13 #include <fst/const-fst.h>
14 #include <fst/lookahead-matcher.h>
15
16
17 namespace fst {
18
19 // Writeable matchers have the same interface as Matchers (as defined in
20 // matcher.h) along with the following additional methods:
21 //
22 // template <class F>
23 // class Matcher {
24 //  public:
25 //   using FST = F;
26 //   ...
27 //   using MatcherData = ...;   // Initialization data.
28 //
29 //   // Constructor with additional argument for external initialization data;
30 //   // matcher increments its reference count on construction and decrements
31 //   // the reference count, and deletes once the reference count has reached
32 //   // zero.
33 //   Matcher(const FST &fst, MatchType type, MatcherData *data);
34 //
35 //   // Returns pointer to initialization data that can be passed to a Matcher
36 //   // constructor.
37 //   MatcherData *GetData() const;
38 // };
39
40 // The matcher initialization data class must also provide the following
41 // interface:
42 //
43 // class MatcherData {
44 // public:
45 //   // Required copy constructor.
46 //   MatcherData(const MatcherData &);
47 //
48 //   // Required I/O methods.
49 //   static MatcherData *Read(std::istream &istrm, const FstReadOptions &opts);
50 //   bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const;
51 // };
52
53 // Trivial (no-op) MatcherFst initializer functor.
54 template <class M>
55 class NullMatcherFstInit {
56  public:
57   using MatcherData = typename M::MatcherData;
58   using Data = AddOnPair<MatcherData, MatcherData>;
59   using Impl = internal::AddOnImpl<typename M::FST, Data>;
60
61   explicit NullMatcherFstInit(std::shared_ptr<Impl> *) {}
62 };
63
64 // Class adding a matcher to an FST type. Creates a new FST whose name is given
65 // by N. An optional functor Init can be used to initialize the FST. The Data
66 // template parameter allows the user to select the type of the add-on.
67 template <
68     class F, class M, const char *Name, class Init = NullMatcherFstInit<M>,
69     class Data = AddOnPair<typename M::MatcherData, typename M::MatcherData>>
70 class MatcherFst : public ImplToExpandedFst<internal::AddOnImpl<F, Data>> {
71  public:
72   using FST = F;
73   using Arc = typename FST::Arc;
74   using StateId = typename Arc::StateId;
75
76   using FstMatcher = M;
77   using MatcherData = typename FstMatcher::MatcherData;
78
79   using Impl = internal::AddOnImpl<FST, Data>;
80   using D = Data;
81
82   friend class StateIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
83   friend class ArcIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
84
85   MatcherFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>(FST(), Name)) {}
86
87   explicit MatcherFst(const FST &fst, std::shared_ptr<Data> data = nullptr)
88       : ImplToExpandedFst<Impl>(data ? CreateImpl(fst, Name, data)
89                                      : CreateDataAndImpl(fst, Name)) {}
90
91   explicit MatcherFst(const Fst<Arc> &fst)
92       : ImplToExpandedFst<Impl>(CreateDataAndImpl(fst, Name)) {}
93
94   // See Fst<>::Copy() for doc.
95   MatcherFst(const MatcherFst<FST, FstMatcher, Name, Init, Data> &fst,
96              bool safe = false)
97       : ImplToExpandedFst<Impl>(fst, safe) {}
98
99   // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc.
100   MatcherFst<FST, FstMatcher, Name, Init, Data> *Copy(
101       bool safe = false) const override {
102     return new MatcherFst<FST, FstMatcher, Name, Init, Data>(*this, safe);
103   }
104
105   // Read a MatcherFst from an input stream; return nullptr on error
106   static MatcherFst<FST, M, Name, Init, Data> *Read(
107       std::istream &strm, const FstReadOptions &opts) {
108     auto *impl = Impl::Read(strm, opts);
109     return impl ? new MatcherFst<FST, FstMatcher, Name, Init, Data>(
110                       std::shared_ptr<Impl>(impl))
111                 : nullptr;
112   }
113
114   // Read a MatcherFst from a file; return nullptr on error
115   // Empty filename reads from standard input
116   static MatcherFst<FST, FstMatcher, Name, Init, Data> *Read(
117       const string &filename) {
118     auto *impl = ImplToExpandedFst<Impl>::Read(filename);
119     return impl ? new MatcherFst<FST, FstMatcher, Name, Init, Data>(
120                       std::shared_ptr<Impl>(impl))
121                 : nullptr;
122   }
123
124   bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
125     return GetImpl()->Write(strm, opts);
126   }
127
128   bool Write(const string &filename) const override {
129     return Fst<Arc>::WriteFile(filename);
130   }
131
132   void InitStateIterator(StateIteratorData<Arc> *data) const override {
133     return GetImpl()->InitStateIterator(data);
134   }
135
136   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
137     return GetImpl()->InitArcIterator(s, data);
138   }
139
140   FstMatcher *InitMatcher(MatchType match_type) const override {
141     return new FstMatcher(GetFst(), match_type, GetSharedData(match_type));
142   }
143
144   const FST &GetFst() const { return GetImpl()->GetFst(); }
145
146   const Data *GetAddOn() const { return GetImpl()->GetAddOn(); }
147
148   std::shared_ptr<Data> GetSharedAddOn() const {
149     return GetImpl()->GetSharedAddOn();
150   }
151
152   const MatcherData *GetData(MatchType match_type) const {
153     const auto *data = GetAddOn();
154     return match_type == MATCH_INPUT ? data->First() : data->Second();
155   }
156
157   std::shared_ptr<MatcherData> GetSharedData(MatchType match_type) const {
158     const auto *data = GetAddOn();
159     return match_type == MATCH_INPUT ? data->SharedFirst()
160                                      : data->SharedSecond();
161   }
162
163  protected:
164   using ImplToFst<Impl, ExpandedFst<Arc>>::GetImpl;
165
166   static std::shared_ptr<Impl> CreateDataAndImpl(const FST &fst,
167                                                  const string &name) {
168     FstMatcher imatcher(fst, MATCH_INPUT);
169     FstMatcher omatcher(fst, MATCH_OUTPUT);
170     return CreateImpl(fst, name,
171                       std::make_shared<Data>(imatcher.GetSharedData(),
172                                              omatcher.GetSharedData()));
173   }
174
175   static std::shared_ptr<Impl> CreateDataAndImpl(const Fst<Arc> &fst,
176                                                  const string &name) {
177     FST result(fst);
178     return CreateDataAndImpl(result, name);
179   }
180
181   static std::shared_ptr<Impl> CreateImpl(const FST &fst, const string &name,
182                                           std::shared_ptr<Data> data) {
183     CHECK(data);
184     auto impl = std::make_shared<Impl>(fst, name);
185     impl->SetAddOn(data);
186     Init init(&impl);
187     return impl;
188   }
189
190   explicit MatcherFst(std::shared_ptr<Impl> impl)
191       : ImplToExpandedFst<Impl>(impl) {}
192
193  private:
194   MatcherFst &operator=(const MatcherFst &) = delete;
195 };
196
197 // Specialization for MatcherFst.
198 template <class FST, class M, const char *Name, class Init>
199 class StateIterator<MatcherFst<FST, M, Name, Init>>
200     : public StateIterator<FST> {
201  public:
202   explicit StateIterator(const MatcherFst<FST, M, Name, Init> &fst)
203       : StateIterator<FST>(fst.GetImpl()->GetFst()) {}
204 };
205
206 // Specialization for MatcherFst.
207 template <class FST, class M, const char *Name, class Init>
208 class ArcIterator<MatcherFst<FST, M, Name, Init>> : public ArcIterator<FST> {
209  public:
210   using StateId = typename FST::Arc::StateId;
211
212   ArcIterator(const MatcherFst<FST, M, Name, Init> &fst,
213               typename FST::Arc::StateId s)
214       : ArcIterator<FST>(fst.GetImpl()->GetFst(), s) {}
215 };
216
217 // Specialization for MatcherFst.
218 template <class F, class M, const char *Name, class Init>
219 class Matcher<MatcherFst<F, M, Name, Init>> {
220  public:
221   using FST = MatcherFst<F, M, Name, Init>;
222   using Arc = typename F::Arc;
223   using Label = typename Arc::Label;
224   using StateId = typename Arc::StateId;
225
226   Matcher(const FST &fst, MatchType match_type)
227       : matcher_(fst.InitMatcher(match_type)) {}
228
229   Matcher(const Matcher<FST> &matcher) : matcher_(matcher.matcher_->Copy()) {}
230
231   Matcher<FST> *Copy() const { return new Matcher<FST>(*this); }
232
233   MatchType Type(bool test) const { return matcher_->Type(test); }
234
235   void SetState(StateId s) { matcher_->SetState(s); }
236
237   bool Find(Label label) { return matcher_->Find(label); }
238
239   bool Done() const { return matcher_->Done(); }
240
241   const Arc &Value() const { return matcher_->Value(); }
242
243   void Next() { matcher_->Next(); }
244
245   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
246
247   uint32 Flags() const { return matcher_->Flags(); }
248
249  private:
250   std::unique_ptr<M> matcher_;
251 };
252
253 // Specialization for MatcherFst.
254 template <class F, class M, const char *Name, class Init>
255 class LookAheadMatcher<MatcherFst<F, M, Name, Init>> {
256  public:
257   using FST = MatcherFst<F, M, Name, Init>;
258   using Arc = typename F::Arc;
259   using Label = typename Arc::Label;
260   using StateId = typename Arc::StateId;
261   using Weight = typename Arc::Weight;
262
263   LookAheadMatcher(const FST &fst, MatchType match_type)
264       : matcher_(fst.InitMatcher(match_type)) {}
265
266   LookAheadMatcher(const LookAheadMatcher<FST> &matcher, bool safe = false)
267       : matcher_(matcher.matcher_->Copy(safe)) {}
268
269   // General matcher methods.
270   LookAheadMatcher<FST> *Copy(bool safe = false) const {
271     return new LookAheadMatcher<FST>(*this, safe);
272   }
273
274   MatchType Type(bool test) const { return matcher_->Type(test); }
275
276   void SetState(StateId s) { matcher_->SetState(s); }
277
278   bool Find(Label label) { return matcher_->Find(label); }
279
280   bool Done() const { return matcher_->Done(); }
281
282   const Arc &Value() const { return matcher_->Value(); }
283
284   void Next() { matcher_->Next(); }
285
286   const FST &GetFst() const { return matcher_->GetFst(); }
287
288   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
289
290   uint32 Flags() const { return matcher_->Flags(); }
291
292   bool LookAheadLabel(Label label) const {
293     return matcher_->LookAheadLabel(label);
294   }
295
296   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
297     return matcher_->LookAheadFst(fst, s);
298   }
299
300   Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); }
301
302   bool LookAheadPrefix(Arc *arc) const {
303     return matcher_->LookAheadPrefix(arc);
304   }
305
306   void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
307     matcher_->InitLookAheadFst(fst, copy);
308   }
309
310  private:
311   std::unique_ptr<M> matcher_;
312 };
313
314 // Useful aliases when using StdArc.
315
316 extern const char arc_lookahead_fst_type[];
317
318 using StdArcLookAheadFst =
319     MatcherFst<ConstFst<StdArc>,
320                ArcLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>>,
321                arc_lookahead_fst_type>;
322
323 extern const char ilabel_lookahead_fst_type[];
324 extern const char olabel_lookahead_fst_type[];
325
326 constexpr auto ilabel_lookahead_flags =
327     kInputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
328     kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
329
330 constexpr auto olabel_lookahead_flags =
331     kOutputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
332     kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
333
334 using StdILabelLookAheadFst = MatcherFst<
335     ConstFst<StdArc>,
336     LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
337                           ilabel_lookahead_flags, FastLogAccumulator<StdArc>>,
338     ilabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
339
340 using StdOLabelLookAheadFst = MatcherFst<
341     ConstFst<StdArc>,
342     LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
343                           olabel_lookahead_flags, FastLogAccumulator<StdArc>>,
344     olabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
345
346 }  // namespace fst
347
348 #endif  // FST_LIB_MATCHER_FST_H_