75fbd0d26bdb2a980a4a40120f9573624af84e51
[platform/upstream/openfst.git] / src / include / fst / matcher.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to allow matching labels leaving FST states.
5
6 #ifndef FST_LIB_MATCHER_H_
7 #define FST_LIB_MATCHER_H_
8
9 #include <algorithm>
10 #include <unordered_map>
11 #include <utility>
12
13 #include <fst/log.h>
14
15 #include <fst/mutable-fst.h>  // for all internal FST accessors.
16
17
18 namespace fst {
19
20 // Matchers find and iterate through requested labels at FST states. In the
21 // simplest form, these are just some associative map or search keyed on labels.
22 // More generally, they may implement matching special labels that represent
23 // sets of labels such as sigma (all), rho (rest), or phi (fail). The Matcher
24 // interface is:
25 //
26 // template <class F>
27 // class Matcher {
28 //  public:
29 //   using FST = F;
30 //   using Arc = typename FST::Arc;
31 //   using Label = typename Arc::Label;
32 //   using StateId = typename Arc::StateId;
33 //   using Weight = typename Arc::Weight;
34 //
35 //   // Required constructors.
36 //
37 //   Matcher(const FST &fst, MatchType type);
38 //   Matcher(const Matcher &matcher, bool safe = false);
39 //
40 //   // Standard copy method.
41 //   Matcher<FST> *Copy(bool safe = false) const override;
42 //
43 //   // Returns the match type that can be provided (depending on compatibility
44 //   of the input FST). It is either the requested match type, MATCH_NONE, or
45 //   MATCH_UNKNOWN. If test is false, a costly testing is avoided, but
46 //   MATCH_UNKNOWN may be returned. If test is true, a definite answer is
47 //   returned, but may involve more costly computation (e.g., visiting the FST).
48 //   MatchType Type(bool test) const override;
49 //
50 //   // Specifies the current state.
51 //   void SetState(StateId s) final;
52 //
53 //   // Finds matches to a label at the current state, returning true if a match
54 //   // found. kNoLabel matches any non-consuming transitions, e.g., epsilon
55 //   // transitions, which do not require a matching symbol.
56 //   bool Find(Label label) final;
57 //
58 //   // Iterator methods. Note that initially and after SetState() these have
59 //   undefined behavior until Find() is called.
60 //
61 //   bool Done() const final;
62 //
63 //   const Arc &Value() const final;
64 //
65 //   void Next() final;
66 //
67 //   // Returns final weight of a state.
68 //   Weight Final(StateId) const final;
69 //
70 //   // Indicates preference for being the side used for matching in
71 //   // composition. If the value is kRequirePriority, then it is
72 //   // mandatory that it be used. Calling this method without passing the
73 //   // current state of the matcher invalidates the state of the matcher.
74 //   ssize_t Priority(StateId s) final;
75 //
76 //   // This specifies the known FST properties as viewed from this matcher. It
77 //   // takes as argument the input FST's known properties.
78 //   uint64 Properties(uint64 props) const override;
79 //
80 //   // Returns matcher flags.
81 //   uint32 Flags() const override;
82 //
83 //   // Returns matcher FST.
84 //   const FST &GetFst() const override;
85 // };
86
87 // Basic matcher flags.
88
89 // Matcher needs to be used as the matching side in composition for
90 // at least one state (has kRequirePriority).
91 constexpr uint32 kRequireMatch = 0x00000001;
92
93 // Flags used for basic matchers (see also lookahead.h).
94 constexpr uint32 kMatcherFlags = kRequireMatch;
95
96 // Matcher priority that is mandatory.
97 constexpr ssize_t kRequirePriority = -1;
98
99 // Matcher interface, templated on the Arc definition; used for matcher
100 // specializations that are returned by the InitMatcher FST method.
101 template <class A>
102 class MatcherBase {
103  public:
104   using Arc = A;
105   using Label = typename Arc::Label;
106   using StateId = typename Arc::StateId;
107   using Weight = typename Arc::Weight;
108
109   virtual ~MatcherBase() {}
110
111   // Virtual interface.
112
113   virtual MatcherBase<Arc> *Copy(bool safe = false) const = 0;
114   virtual MatchType Type(bool) const = 0;
115   virtual void SetState(StateId) = 0;
116   virtual bool Find(Label) = 0;
117   virtual bool Done() const = 0;
118   virtual const Arc &Value() const = 0;
119   virtual void Next() = 0;
120   virtual const Fst<Arc> &GetFst() const = 0;
121   virtual uint64 Properties(uint64) const = 0;
122
123   // Trivial implementations that can be used by derived classes. Full
124   // devirtualization is expected for any derived class marked final.
125   virtual uint32 Flags() const { return 0; }
126
127   virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); }
128
129   virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); }
130 };
131
132 // A matcher that expects sorted labels on the side to be matched.
133 // If match_type == MATCH_INPUT, epsilons match the implicit self-loop
134 // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
135 // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
136 // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
137 template <class F>
138 class SortedMatcher : public MatcherBase<typename F::Arc> {
139  public:
140   using FST = F;
141   using Arc = typename FST::Arc;
142   using Label = typename Arc::Label;
143   using StateId = typename Arc::StateId;
144   using Weight = typename Arc::Weight;
145
146   using MatcherBase<Arc>::Flags;
147   using MatcherBase<Arc>::Properties;
148
149   // Labels >= binary_label will be searched for by binary search;
150   // o.w. linear search is used.
151   SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1)
152       : fst_(fst.Copy()),
153         state_(kNoStateId),
154         aiter_(nullptr),
155         match_type_(match_type),
156         binary_label_(binary_label),
157         match_label_(kNoLabel),
158         narcs_(0),
159         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
160         error_(false),
161         aiter_pool_(1) {
162     switch (match_type_) {
163       case MATCH_INPUT:
164       case MATCH_NONE:
165         break;
166       case MATCH_OUTPUT:
167         std::swap(loop_.ilabel, loop_.olabel);
168         break;
169       default:
170         FSTERROR() << "SortedMatcher: Bad match type";
171         match_type_ = MATCH_NONE;
172         error_ = true;
173     }
174   }
175
176   SortedMatcher(const SortedMatcher<FST> &matcher, bool safe = false)
177       : fst_(matcher.fst_->Copy(safe)),
178         state_(kNoStateId),
179         aiter_(nullptr),
180         match_type_(matcher.match_type_),
181         binary_label_(matcher.binary_label_),
182         match_label_(kNoLabel),
183         narcs_(0),
184         loop_(matcher.loop_),
185         error_(matcher.error_),
186         aiter_pool_(1) {}
187
188   ~SortedMatcher() override { Destroy(aiter_, &aiter_pool_); }
189
190   SortedMatcher<FST> *Copy(bool safe = false) const override {
191     return new SortedMatcher<FST>(*this, safe);
192   }
193
194   MatchType Type(bool test) const override {
195     if (match_type_ == MATCH_NONE) return match_type_;
196     const auto true_prop =
197         match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
198     const auto false_prop =
199         match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
200     const auto props = fst_->Properties(true_prop | false_prop, test);
201     if (props & true_prop) {
202       return match_type_;
203     } else if (props & false_prop) {
204       return MATCH_NONE;
205     } else {
206       return MATCH_UNKNOWN;
207     }
208   }
209
210   void SetState(StateId s) final {
211     if (state_ == s) return;
212     state_ = s;
213     if (match_type_ == MATCH_NONE) {
214       FSTERROR() << "SortedMatcher: Bad match type";
215       error_ = true;
216     }
217     Destroy(aiter_, &aiter_pool_);
218     aiter_ = new (&aiter_pool_) ArcIterator<FST>(*fst_, s);
219     aiter_->SetFlags(kArcNoCache, kArcNoCache);
220     narcs_ = internal::NumArcs(*fst_, s);
221     loop_.nextstate = s;
222   }
223
224   bool Find(Label match_label) final {
225     exact_match_ = true;
226     if (error_) {
227       current_loop_ = false;
228       match_label_ = kNoLabel;
229       return false;
230     }
231     current_loop_ = match_label == 0;
232     match_label_ = match_label == kNoLabel ? 0 : match_label;
233     if (Search()) {
234       return true;
235     } else {
236       return current_loop_;
237     }
238   }
239
240   // Positions matcher to the first position where inserting match_label would
241   // maintain the sort order.
242   void LowerBound(Label label) {
243     exact_match_ = false;
244     current_loop_ = false;
245     if (error_) {
246       match_label_ = kNoLabel;
247       return;
248     }
249     match_label_ = label;
250     Search();
251   }
252
253   // After Find(), returns false if no more exact matches.
254   // After LowerBound(), returns false if no more arcs.
255   bool Done() const final {
256     if (current_loop_) return false;
257     if (aiter_->Done()) return true;
258     if (!exact_match_) return false;
259     aiter_->SetFlags(
260         match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
261         kArcValueFlags);
262     const auto label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel
263                                                   : aiter_->Value().olabel;
264     return label != match_label_;
265   }
266
267   const Arc &Value() const final {
268     if (current_loop_) return loop_;
269     aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
270     return aiter_->Value();
271   }
272
273   void Next() final {
274     if (current_loop_) {
275       current_loop_ = false;
276     } else {
277       aiter_->Next();
278     }
279   }
280
281   Weight Final(StateId s) const final {
282     return MatcherBase<Arc>::Final(s);
283   }
284
285   ssize_t Priority(StateId s) final {
286     return MatcherBase<Arc>::Priority(s);
287   }
288
289   const FST &GetFst() const override { return *fst_; }
290
291   uint64 Properties(uint64 inprops) const override {
292     return inprops | (error_ ? kError : 0);
293   }
294
295   size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
296
297  private:
298   bool Search();
299
300   std::unique_ptr<const FST> fst_;
301   StateId state_;            // Matcher state.
302   ArcIterator<FST> *aiter_;  // Iterator for current state.
303   MatchType match_type_;     // Type of match to perform.
304   Label binary_label_;       // Least label for binary search.
305   Label match_label_;        // Current label to be matched.
306   size_t narcs_;             // Current state arc count.
307   Arc loop_;                 // For non-consuming symbols.
308   bool current_loop_;        // Current arc is the implicit loop.
309   bool exact_match_;         // Exact match or lower bound?
310   bool error_;               // Error encountered?
311   MemoryPool<ArcIterator<FST>> aiter_pool_;  // Pool of arc iterators.
312 };
313
314 // Returns true iff match to match_label_, positioning arc iterator at lower
315 // bound.
316 template <class FST>
317 inline bool SortedMatcher<FST>::Search() {
318   aiter_->SetFlags(
319       match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
320       kArcValueFlags);
321   if (match_label_ >= binary_label_) {
322     // Binary search for match.
323     size_t low = 0;
324     size_t high = narcs_;
325     while (low < high) {
326       const size_t mid = (low + high) / 2;
327       aiter_->Seek(mid);
328       auto label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel
329                                               : aiter_->Value().olabel;
330       if (label > match_label_) {
331         high = mid;
332       } else if (label < match_label_) {
333         low = mid + 1;
334       } else {
335         // Finds first matching label (when non-deterministic).
336         for (size_t i = mid; i > low; --i) {
337           aiter_->Seek(i - 1);
338           label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel
339                                              : aiter_->Value().olabel;
340           if (label != match_label_) {
341             aiter_->Seek(i);
342             return true;
343           }
344         }
345         return true;
346       }
347     }
348     aiter_->Seek(low);
349     return false;
350   } else {
351     // Linear search for match.
352     for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
353       const auto label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel
354                                                     : aiter_->Value().olabel;
355       if (label == match_label_) return true;
356       if (label > match_label_) break;
357     }
358     return false;
359   }
360 }
361
362 // A matcher that stores labels in a per-state hash table populated upon the
363 // first visit to that state. Sorting is not required. Treatment of
364 // epsilons are the same as with SortedMatcher.
365 template <class F>
366 class HashMatcher : public MatcherBase<typename F::Arc> {
367  public:
368   using FST = F;
369   using Arc = typename FST::Arc;
370   using Label = typename Arc::Label;
371   using StateId = typename Arc::StateId;
372   using Weight = typename Arc::Weight;
373
374   using MatcherBase<Arc>::Flags;
375   using MatcherBase<Arc>::Final;
376   using MatcherBase<Arc>::Priority;
377
378   HashMatcher(const FST &fst, MatchType match_type)
379       : fst_(fst.Copy()),
380         state_(kNoStateId),
381         match_type_(match_type),
382         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
383     switch (match_type_) {
384       case MATCH_INPUT:
385       case MATCH_NONE:
386         break;
387       case MATCH_OUTPUT:
388         std::swap(loop_.ilabel, loop_.olabel);
389         break;
390       default:
391         FSTERROR() << "SortedMatcher: Bad match type";
392         match_type_ = MATCH_NONE;
393         error_ = true;
394     }
395   }
396
397   HashMatcher(const HashMatcher<FST> &matcher, bool safe = false)
398       : fst_(matcher.fst_->Copy(safe)),
399         state_(kNoStateId),
400         match_type_(matcher.match_type_),
401         loop_(matcher.loop_),
402         error_(matcher.error_) {}
403
404   HashMatcher<FST> *Copy(bool safe = false) const override {
405     return new HashMatcher<FST>(*this, safe);
406   }
407
408   // The argument is ignored as there are no relevant properties to test.
409   MatchType Type(bool test) const override { return match_type_; }
410
411   void SetState(StateId s) final;
412
413   bool Find(Label label) final {
414     current_loop_ = label == 0;
415     if (label == 0) {
416       Search(label);
417       return true;
418     }
419     if (label == kNoLabel) label = 0;
420     return Search(label);
421   }
422
423   bool Done() const final {
424     if (current_loop_) return false;
425     return label_it_ == label_end_;
426   }
427
428   const Arc &Value() const final {
429     if (current_loop_) return loop_;
430     aiter_->Seek(label_it_->second);
431     return aiter_->Value();
432   }
433
434   void Next() final {
435     if (current_loop_) {
436       current_loop_ = false;
437     } else {
438       ++label_it_;
439     }
440   }
441
442   const FST &GetFst() const override { return *fst_; }
443
444   uint64 Properties(uint64 inprops) const override {
445     return inprops | (error_ ? kError : 0);
446   }
447
448  private:
449   bool Search(Label match_label);
450
451   using LabelTable = std::unordered_multimap<Label, size_t>;
452   using StateTable = std::unordered_map<StateId, LabelTable>;
453
454   std::unique_ptr<const FST> fst_;
455   StateId state_;  // Matcher state.
456   MatchType match_type_;
457   Arc loop_;           // The implicit loop itself.
458   bool current_loop_;  // Is the current arc is the implicit loop?
459   bool error_;         // Error encountered?
460   std::unique_ptr<ArcIterator<FST>> aiter_;
461   StateTable state_table_;   // Table from states to label table.
462   LabelTable *label_table_;  // Pointer to current state's label table.
463   typename LabelTable::iterator label_it_;   // Position for label.
464   typename LabelTable::iterator label_end_;  // Position for last label + 1.
465 };
466
467 template <class FST>
468 void HashMatcher<FST>::SetState(typename FST::Arc::StateId s) {
469   if (state_ == s) return;
470   // Resets everything for the state.
471   state_ = s;
472   loop_.nextstate = state_;
473   aiter_.reset(new ArcIterator<FST>(*fst_, state_));
474   if (match_type_ == MATCH_NONE) {
475     FSTERROR() << "HashMatcher: Bad match type";
476     error_ = false;
477   }
478   // Attempts to insert a new label table; if it already exists,
479   // no additional work is done and we simply return.
480   auto it_and_success = state_table_.emplace(state_, LabelTable());
481   if (!it_and_success.second) return;
482   // Otherwise, populate this new table.
483   // Sets instance's pointer to the label table for this state.
484   label_table_ = &(it_and_success.first->second);
485   // Populates the label table.
486   label_table_->reserve(internal::NumArcs(*fst_, state_));
487   const auto aiter_flags =
488       (match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue) |
489       kArcNoCache;
490   aiter_->SetFlags(aiter_flags, kArcFlags);
491   for (; !aiter_->Done(); aiter_->Next()) {
492     const auto label = (match_type_ == MATCH_INPUT) ? aiter_->Value().ilabel
493                                                     : aiter_->Value().olabel;
494     label_table_->emplace(label, aiter_->Position());
495   }
496   aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
497 }
498
499 template <class FST>
500 inline bool HashMatcher<FST>::Search(typename FST::Arc::Label match_label) {
501   auto range = label_table_->equal_range(match_label);
502   if (range.first == range.second) return false;
503   label_it_ = range.first;
504   label_end_ = range.second;
505   aiter_->Seek(label_it_->second);
506   return true;
507 }
508
509 // Specifies whether we rewrite both the input and output sides during matching.
510 enum MatcherRewriteMode {
511   MATCHER_REWRITE_AUTO = 0,  // Rewrites both sides iff acceptor.
512   MATCHER_REWRITE_ALWAYS,
513   MATCHER_REWRITE_NEVER
514 };
515
516 // For any requested label that doesn't match at a state, this matcher
517 // considers the *unique* transition that matches the label 'phi_label'
518 // (phi = 'fail'), and recursively looks for a match at its
519 // destination.  When 'phi_loop' is true, if no match is found but a
520 // phi self-loop is found, then the phi transition found is returned
521 // with the phi_label rewritten as the requested label (both sides if
522 // an acceptor, or if 'rewrite_both' is true and both input and output
523 // labels of the found transition are 'phi_label').  If 'phi_label' is
524 // kNoLabel, this special matching is not done.  PhiMatcher is
525 // templated itself on a matcher, which is used to perform the
526 // underlying matching.  By default, the underlying matcher is
527 // constructed by PhiMatcher. The user can instead pass in this
528 // object; in that case, PhiMatcher takes its ownership.
529 // Phi non-determinism not supported. No non-consuming symbols other
530 // than epsilon supported with the underlying template argument matcher.
531 template <class M>
532 class PhiMatcher : public MatcherBase<typename M::Arc> {
533  public:
534   using FST = typename M::FST;
535   using Arc = typename FST::Arc;
536   using Label = typename Arc::Label;
537   using StateId = typename Arc::StateId;
538   using Weight = typename Arc::Weight;
539
540   PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel,
541              bool phi_loop = true,
542              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
543              M *matcher = nullptr)
544       : matcher_(matcher ? matcher : new M(fst, match_type)),
545         match_type_(match_type),
546         phi_label_(phi_label),
547         state_(kNoStateId),
548         phi_loop_(phi_loop),
549         error_(false) {
550     if (match_type == MATCH_BOTH) {
551       FSTERROR() << "PhiMatcher: Bad match type";
552       match_type_ = MATCH_NONE;
553       error_ = true;
554     }
555     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
556       rewrite_both_ = fst.Properties(kAcceptor, true);
557     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
558       rewrite_both_ = true;
559     } else {
560       rewrite_both_ = false;
561     }
562   }
563
564   PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
565       : matcher_(new M(*matcher.matcher_, safe)),
566         match_type_(matcher.match_type_),
567         phi_label_(matcher.phi_label_),
568         rewrite_both_(matcher.rewrite_both_),
569         state_(kNoStateId),
570         phi_loop_(matcher.phi_loop_),
571         error_(matcher.error_) {}
572
573   PhiMatcher<M> *Copy(bool safe = false) const override {
574     return new PhiMatcher<M>(*this, safe);
575   }
576
577   MatchType Type(bool test) const override { return matcher_->Type(test); }
578
579   void SetState(StateId s) final {
580     if (state_ == s) return;
581     matcher_->SetState(s);
582     state_ = s;
583     has_phi_ = phi_label_ != kNoLabel;
584   }
585
586   bool Find(Label match_label) final;
587
588   bool Done() const final { return matcher_->Done(); }
589
590   const Arc &Value() const final {
591     if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
592       return matcher_->Value();
593     } else if (phi_match_ == 0) {  // Virtual epsilon loop.
594       phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
595       if (match_type_ == MATCH_OUTPUT) {
596         std::swap(phi_arc_.ilabel, phi_arc_.olabel);
597       }
598       return phi_arc_;
599     } else {
600       phi_arc_ = matcher_->Value();
601       phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
602       if (phi_match_ != kNoLabel) {  // Phi loop match.
603         if (rewrite_both_) {
604           if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_;
605           if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_;
606         } else if (match_type_ == MATCH_INPUT) {
607           phi_arc_.ilabel = phi_match_;
608         } else {
609           phi_arc_.olabel = phi_match_;
610         }
611       }
612       return phi_arc_;
613     }
614   }
615
616   void Next() final { matcher_->Next(); }
617
618   Weight Final(StateId s) const final {
619     auto weight = matcher_->Final(s);
620     if (phi_label_ == kNoLabel || weight != Weight::Zero()) {
621       return weight;
622     }
623     weight = Weight::One();
624     matcher_->SetState(s);
625     while (matcher_->Final(s) == Weight::Zero()) {
626       if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) break;
627       weight = Times(weight, matcher_->Value().weight);
628       if (s == matcher_->Value().nextstate) {
629         return Weight::Zero();  // Does not follow phi self-loops.
630       }
631       s = matcher_->Value().nextstate;
632       matcher_->SetState(s);
633     }
634     weight = Times(weight, matcher_->Final(s));
635     return weight;
636   }
637
638   ssize_t Priority(StateId s) final {
639     if (phi_label_ != kNoLabel) {
640       matcher_->SetState(s);
641       const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_);
642       return has_phi ? kRequirePriority : matcher_->Priority(s);
643     } else {
644       return matcher_->Priority(s);
645     }
646   }
647
648   const FST &GetFst() const override { return matcher_->GetFst(); }
649
650   uint64 Properties(uint64 props) const override;
651
652   uint32 Flags() const override {
653     if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) {
654       return matcher_->Flags();
655     }
656     return matcher_->Flags() | kRequireMatch;
657   }
658
659   Label PhiLabel() const { return phi_label_; }
660
661  private:
662   mutable std::unique_ptr<M> matcher_;
663   MatchType match_type_;  // Type of match requested.
664   Label phi_label_;       // Label that represents the phi transition.
665   bool rewrite_both_;     // Rewrite both sides when both are phi_label_?
666   bool has_phi_;          // Are there possibly phis at the current state?
667   Label phi_match_;       // Current label that matches phi loop.
668   mutable Arc phi_arc_;   // Arc to return.
669   StateId state_;         // Matcher state.
670   Weight phi_weight_;     // Product of the weights of phi transitions taken.
671   bool phi_loop_;         // When true, phi self-loop are allowed and treated
672                           // as rho (required for Aho-Corasick).
673   bool error_;            // Error encountered?
674
675   PhiMatcher &operator=(const PhiMatcher &) = delete;
676 };
677
678 template <class M>
679 inline bool PhiMatcher<M>::Find(Label label) {
680   if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
681     FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
682     error_ = true;
683     return false;
684   }
685   matcher_->SetState(state_);
686   phi_match_ = kNoLabel;
687   phi_weight_ = Weight::One();
688   // If phi_label_ == 0, there are no more true epsilon arcs.
689   if (phi_label_ == 0) {
690     if (label == kNoLabel) {
691       return false;
692     }
693     if (label == 0) {  // but a virtual epsilon loop needs to be returned.
694       if (!matcher_->Find(kNoLabel)) {
695         return matcher_->Find(0);
696       } else {
697         phi_match_ = 0;
698         return true;
699       }
700     }
701   }
702   if (!has_phi_ || label == 0 || label == kNoLabel) {
703     return matcher_->Find(label);
704   }
705   auto s = state_;
706   while (!matcher_->Find(label)) {
707     // Look for phi transition (if phi_label_ == 0, we need to look
708     // for -1 to avoid getting the virtual self-loop)
709     if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false;
710     if (phi_loop_ && matcher_->Value().nextstate == s) {
711       phi_match_ = label;
712       return true;
713     }
714     phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
715     s = matcher_->Value().nextstate;
716     matcher_->Next();
717     if (!matcher_->Done()) {
718       FSTERROR() << "PhiMatcher: Phi non-determinism not supported";
719       error_ = true;
720     }
721     matcher_->SetState(s);
722   }
723   return true;
724 }
725
726 template <class M>
727 inline uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
728   auto outprops = matcher_->Properties(inprops);
729   if (error_) outprops |= kError;
730   if (match_type_ == MATCH_NONE) {
731     return outprops;
732   } else if (match_type_ == MATCH_INPUT) {
733     if (phi_label_ == 0) {
734       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
735       outprops |= kNoEpsilons | kNoIEpsilons;
736     }
737     if (rewrite_both_) {
738       return outprops &
739              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
740                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
741     } else {
742       return outprops &
743              ~(kODeterministic | kAcceptor | kString | kILabelSorted |
744                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
745     }
746   } else if (match_type_ == MATCH_OUTPUT) {
747     if (phi_label_ == 0) {
748       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
749       outprops |= kNoEpsilons | kNoOEpsilons;
750     }
751     if (rewrite_both_) {
752       return outprops &
753              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
754                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
755     } else {
756       return outprops &
757              ~(kIDeterministic | kAcceptor | kString | kILabelSorted |
758                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
759     }
760   } else {
761     // Shouldn't ever get here.
762     FSTERROR() << "PhiMatcher: Bad match type: " << match_type_;
763     return 0;
764   }
765 }
766
767 // For any requested label that doesn't match at a state, this matcher
768 // considers all transitions that match the label 'rho_label' (rho =
769 // 'rest').  Each such rho transition found is returned with the
770 // rho_label rewritten as the requested label (both sides if an
771 // acceptor, or if 'rewrite_both' is true and both input and output
772 // labels of the found transition are 'rho_label').  If 'rho_label' is
773 // kNoLabel, this special matching is not done.  RhoMatcher is
774 // templated itself on a matcher, which is used to perform the
775 // underlying matching.  By default, the underlying matcher is
776 // constructed by RhoMatcher.  The user can instead pass in this
777 // object; in that case, RhoMatcher takes its ownership.
778 // No non-consuming symbols other than epsilon supported with
779 // the underlying template argument matcher.
780 template <class M>
781 class RhoMatcher : public MatcherBase<typename M::Arc> {
782  public:
783   using FST = typename M::FST;
784   using Arc = typename FST::Arc;
785   using Label = typename Arc::Label;
786   using StateId = typename Arc::StateId;
787   using Weight = typename Arc::Weight;
788
789   RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel,
790              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
791              M *matcher = nullptr)
792       : matcher_(matcher ? matcher : new M(fst, match_type)),
793         match_type_(match_type),
794         rho_label_(rho_label),
795         error_(false),
796         state_(kNoStateId) {
797     if (match_type == MATCH_BOTH) {
798       FSTERROR() << "RhoMatcher: Bad match type";
799       match_type_ = MATCH_NONE;
800       error_ = true;
801     }
802     if (rho_label == 0) {
803       FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
804       rho_label_ = kNoLabel;
805       error_ = true;
806     }
807     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
808       rewrite_both_ = fst.Properties(kAcceptor, true);
809     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
810       rewrite_both_ = true;
811     } else {
812       rewrite_both_ = false;
813     }
814   }
815
816   RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
817       : matcher_(new M(*matcher.matcher_, safe)),
818         match_type_(matcher.match_type_),
819         rho_label_(matcher.rho_label_),
820         rewrite_both_(matcher.rewrite_both_),
821         error_(matcher.error_),
822         state_(kNoStateId) {}
823
824   RhoMatcher<M> *Copy(bool safe = false) const override {
825     return new RhoMatcher<M>(*this, safe);
826   }
827
828   MatchType Type(bool test) const override { return matcher_->Type(test); }
829
830   void SetState(StateId s) final {
831     if (state_ == s) return;
832     state_ = s;
833     matcher_->SetState(s);
834     has_rho_ = rho_label_ != kNoLabel;
835   }
836
837   bool Find(Label label) final {
838     if (label == rho_label_ && rho_label_ != kNoLabel) {
839       FSTERROR() << "RhoMatcher::Find: bad label (rho)";
840       error_ = true;
841       return false;
842     }
843     if (matcher_->Find(label)) {
844       rho_match_ = kNoLabel;
845       return true;
846     } else if (has_rho_ && label != 0 && label != kNoLabel &&
847                (has_rho_ = matcher_->Find(rho_label_))) {
848       rho_match_ = label;
849       return true;
850     } else {
851       return false;
852     }
853   }
854
855   bool Done() const final { return matcher_->Done(); }
856
857   const Arc &Value() const final {
858     if (rho_match_ == kNoLabel) {
859       return matcher_->Value();
860     } else {
861       rho_arc_ = matcher_->Value();
862       if (rewrite_both_) {
863         if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_;
864         if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_;
865       } else if (match_type_ == MATCH_INPUT) {
866         rho_arc_.ilabel = rho_match_;
867       } else {
868         rho_arc_.olabel = rho_match_;
869       }
870       return rho_arc_;
871     }
872   }
873
874   void Next() final { matcher_->Next(); }
875
876   Weight Final(StateId s) const final { return matcher_->Final(s); }
877
878   ssize_t Priority(StateId s) final {
879     state_ = s;
880     matcher_->SetState(s);
881     has_rho_ = matcher_->Find(rho_label_);
882     if (has_rho_) {
883       return kRequirePriority;
884     } else {
885       return matcher_->Priority(s);
886     }
887   }
888
889   const FST &GetFst() const override { return matcher_->GetFst(); }
890
891   uint64 Properties(uint64 props) const override;
892
893   uint32 Flags() const override {
894     if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) {
895       return matcher_->Flags();
896     }
897     return matcher_->Flags() | kRequireMatch;
898   }
899
900   Label RhoLabel() const { return rho_label_; }
901
902  private:
903   std::unique_ptr<M> matcher_;
904   MatchType match_type_;  // Type of match requested.
905   Label rho_label_;       // Label that represents the rho transition
906   bool rewrite_both_;     // Rewrite both sides when both are rho_label_?
907   bool has_rho_;          // Are there possibly rhos at the current state?
908   Label rho_match_;       // Current label that matches rho transition.
909   mutable Arc rho_arc_;   // Arc to return when rho match.
910   bool error_;            // Error encountered?
911   StateId state_;         // Matcher state.
912 };
913
914 template <class M>
915 inline uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
916   auto outprops = matcher_->Properties(inprops);
917   if (error_) outprops |= kError;
918   if (match_type_ == MATCH_NONE) {
919     return outprops;
920   } else if (match_type_ == MATCH_INPUT) {
921     if (rewrite_both_) {
922       return outprops &
923              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
924                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
925     } else {
926       return outprops &
927              ~(kODeterministic | kAcceptor | kString | kILabelSorted |
928                kNotILabelSorted);
929     }
930   } else if (match_type_ == MATCH_OUTPUT) {
931     if (rewrite_both_) {
932       return outprops &
933              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
934                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
935     } else {
936       return outprops &
937              ~(kIDeterministic | kAcceptor | kString | kOLabelSorted |
938                kNotOLabelSorted);
939     }
940   } else {
941     // Shouldn't ever get here.
942     FSTERROR() << "RhoMatcher: Bad match type: " << match_type_;
943     return 0;
944   }
945 }
946
947 // For any requested label, this matcher considers all transitions
948 // that match the label 'sigma_label' (sigma = "any"), and this in
949 // additions to transitions with the requested label.  Each such sigma
950 // transition found is returned with the sigma_label rewritten as the
951 // requested label (both sides if an acceptor, or if 'rewrite_both' is
952 // true and both input and output labels of the found transition are
953 // 'sigma_label').  If 'sigma_label' is kNoLabel, this special
954 // matching is not done.  SigmaMatcher is templated itself on a
955 // matcher, which is used to perform the underlying matching.  By
956 // default, the underlying matcher is constructed by SigmaMatcher.
957 // The user can instead pass in this object; in that case,
958 // SigmaMatcher takes its ownership.  No non-consuming symbols other
959 // than epsilon supported with the underlying template argument matcher.
960 template <class M>
961 class SigmaMatcher : public MatcherBase<typename M::Arc> {
962  public:
963   using FST = typename M::FST;
964   using Arc = typename FST::Arc;
965   using Label = typename Arc::Label;
966   using StateId = typename Arc::StateId;
967   using Weight = typename Arc::Weight;
968
969   SigmaMatcher(const FST &fst, MatchType match_type,
970                Label sigma_label = kNoLabel,
971                MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
972                M *matcher = nullptr)
973       : matcher_(matcher ? matcher : new M(fst, match_type)),
974         match_type_(match_type),
975         sigma_label_(sigma_label),
976         error_(false),
977         state_(kNoStateId) {
978     if (match_type == MATCH_BOTH) {
979       FSTERROR() << "SigmaMatcher: Bad match type";
980       match_type_ = MATCH_NONE;
981       error_ = true;
982     }
983     if (sigma_label == 0) {
984       FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
985       sigma_label_ = kNoLabel;
986       error_ = true;
987     }
988     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
989       rewrite_both_ = fst.Properties(kAcceptor, true);
990     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
991       rewrite_both_ = true;
992     } else {
993       rewrite_both_ = false;
994     }
995   }
996
997   SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
998       : matcher_(new M(*matcher.matcher_, safe)),
999         match_type_(matcher.match_type_),
1000         sigma_label_(matcher.sigma_label_),
1001         rewrite_both_(matcher.rewrite_both_),
1002         error_(matcher.error_),
1003         state_(kNoStateId) {}
1004
1005   SigmaMatcher<M> *Copy(bool safe = false) const override {
1006     return new SigmaMatcher<M>(*this, safe);
1007   }
1008
1009   MatchType Type(bool test) const override { return matcher_->Type(test); }
1010
1011   void SetState(StateId s) final {
1012     if (state_ == s) return;
1013     state_ = s;
1014     matcher_->SetState(s);
1015     has_sigma_ =
1016         (sigma_label_ != kNoLabel) ? matcher_->Find(sigma_label_) : false;
1017   }
1018
1019   bool Find(Label match_label) final {
1020     match_label_ = match_label;
1021     if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
1022       FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
1023       error_ = true;
1024       return false;
1025     }
1026     if (matcher_->Find(match_label)) {
1027       sigma_match_ = kNoLabel;
1028       return true;
1029     } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
1030                matcher_->Find(sigma_label_)) {
1031       sigma_match_ = match_label;
1032       return true;
1033     } else {
1034       return false;
1035     }
1036   }
1037
1038   bool Done() const final { return matcher_->Done(); }
1039
1040   const Arc &Value() const final {
1041     if (sigma_match_ == kNoLabel) {
1042       return matcher_->Value();
1043     } else {
1044       sigma_arc_ = matcher_->Value();
1045       if (rewrite_both_) {
1046         if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_;
1047         if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_;
1048       } else if (match_type_ == MATCH_INPUT) {
1049         sigma_arc_.ilabel = sigma_match_;
1050       } else {
1051         sigma_arc_.olabel = sigma_match_;
1052       }
1053       return sigma_arc_;
1054     }
1055   }
1056
1057   void Next() final {
1058     matcher_->Next();
1059     if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
1060         (match_label_ > 0)) {
1061       matcher_->Find(sigma_label_);
1062       sigma_match_ = match_label_;
1063     }
1064   }
1065
1066   Weight Final(StateId s) const final { return matcher_->Final(s); }
1067
1068   ssize_t Priority(StateId s) final {
1069     if (sigma_label_ != kNoLabel) {
1070       SetState(s);
1071       return has_sigma_ ? kRequirePriority : matcher_->Priority(s);
1072     } else {
1073       return matcher_->Priority(s);
1074     }
1075   }
1076
1077   const FST &GetFst() const override { return matcher_->GetFst(); }
1078
1079   uint64 Properties(uint64 props) const override;
1080
1081   uint32 Flags() const override {
1082     if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) {
1083       return matcher_->Flags();
1084     }
1085     return matcher_->Flags() | kRequireMatch;
1086   }
1087
1088   Label SigmaLabel() const { return sigma_label_; }
1089
1090  private:
1091   std::unique_ptr<M> matcher_;
1092   MatchType match_type_;   // Type of match requested.
1093   Label sigma_label_;      // Label that represents the sigma transition.
1094   bool rewrite_both_;      // Rewrite both sides when both are sigma_label_?
1095   bool has_sigma_;         // Are there sigmas at the current state?
1096   Label sigma_match_;      // Current label that matches sigma transition.
1097   mutable Arc sigma_arc_;  // Arc to return when sigma match.
1098   Label match_label_;      // Label being matched.
1099   bool error_;             // Error encountered?
1100   StateId state_;          // Matcher state.
1101 };
1102
1103 template <class M>
1104 inline uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
1105   auto outprops = matcher_->Properties(inprops);
1106   if (error_) outprops |= kError;
1107   if (match_type_ == MATCH_NONE) {
1108     return outprops;
1109   } else if (rewrite_both_) {
1110     return outprops &
1111            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1112              kNonODeterministic | kILabelSorted | kNotILabelSorted |
1113              kOLabelSorted | kNotOLabelSorted | kString);
1114   } else if (match_type_ == MATCH_INPUT) {
1115     return outprops &
1116            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1117              kNonODeterministic | kILabelSorted | kNotILabelSorted | kString |
1118              kAcceptor);
1119   } else if (match_type_ == MATCH_OUTPUT) {
1120     return outprops &
1121            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1122              kNonODeterministic | kOLabelSorted | kNotOLabelSorted | kString |
1123              kAcceptor);
1124   } else {
1125     // Shouldn't ever get here.
1126     FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_;
1127     return 0;
1128   }
1129 }
1130
1131 // Flags for MultiEpsMatcher.
1132
1133 // Return multi-epsilon arcs for Find(kNoLabel).
1134 const uint32 kMultiEpsList = 0x00000001;
1135
1136 // Return a kNolabel loop for Find(multi_eps).
1137 const uint32 kMultiEpsLoop = 0x00000002;
1138
1139 // MultiEpsMatcher: allows treating multiple non-0 labels as
1140 // non-consuming labels in addition to 0 that is always
1141 // non-consuming. Precise behavior controlled by 'flags' argument. By
1142 // default, the underlying matcher is constructed by
1143 // MultiEpsMatcher. The user can instead pass in this object; in that
1144 // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
1145 // true.
1146 template <class M>
1147 class MultiEpsMatcher {
1148  public:
1149   using FST = typename M::FST;
1150   using Arc = typename FST::Arc;
1151   using Label = typename Arc::Label;
1152   using StateId = typename Arc::StateId;
1153   using Weight = typename Arc::Weight;
1154
1155   MultiEpsMatcher(const FST &fst, MatchType match_type,
1156                   uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1157                   M *matcher = nullptr, bool own_matcher = true)
1158       : matcher_(matcher ? matcher : new M(fst, match_type)),
1159         flags_(flags),
1160         own_matcher_(matcher ? own_matcher : true) {
1161     if (match_type == MATCH_INPUT) {
1162       loop_.ilabel = kNoLabel;
1163       loop_.olabel = 0;
1164     } else {
1165       loop_.ilabel = 0;
1166       loop_.olabel = kNoLabel;
1167     }
1168     loop_.weight = Weight::One();
1169     loop_.nextstate = kNoStateId;
1170   }
1171
1172   MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
1173       : matcher_(new M(*matcher.matcher_, safe)),
1174         flags_(matcher.flags_),
1175         own_matcher_(true),
1176         multi_eps_labels_(matcher.multi_eps_labels_),
1177         loop_(matcher.loop_) {
1178     loop_.nextstate = kNoStateId;
1179   }
1180
1181   ~MultiEpsMatcher() {
1182     if (own_matcher_) delete matcher_;
1183   }
1184
1185   MultiEpsMatcher<M> *Copy(bool safe = false) const {
1186     return new MultiEpsMatcher<M>(*this, safe);
1187   }
1188
1189   MatchType Type(bool test) const { return matcher_->Type(test); }
1190
1191   void SetState(StateId state) {
1192     matcher_->SetState(state);
1193     loop_.nextstate = state;
1194   }
1195
1196   bool Find(Label label);
1197
1198   bool Done() const { return done_; }
1199
1200   const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); }
1201
1202   void Next() {
1203     if (!current_loop_) {
1204       matcher_->Next();
1205       done_ = matcher_->Done();
1206       if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1207         ++multi_eps_iter_;
1208         while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1209                !matcher_->Find(*multi_eps_iter_)) {
1210           ++multi_eps_iter_;
1211         }
1212         if (multi_eps_iter_ != multi_eps_labels_.End()) {
1213           done_ = false;
1214         } else {
1215           done_ = !matcher_->Find(kNoLabel);
1216         }
1217       }
1218     } else {
1219       done_ = true;
1220     }
1221   }
1222
1223   const FST &GetFst() const { return matcher_->GetFst(); }
1224
1225   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1226
1227   const M *GetMatcher() const { return matcher_; }
1228
1229   Weight Final(StateId s) const { return matcher_->Final(s); }
1230
1231   uint32 Flags() const { return matcher_->Flags(); }
1232
1233   ssize_t Priority(StateId s) { return matcher_->Priority(s); }
1234
1235   void AddMultiEpsLabel(Label label) {
1236     if (label == 0) {
1237       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1238     } else {
1239       multi_eps_labels_.Insert(label);
1240     }
1241   }
1242
1243   void RemoveMultiEpsLabel(Label label) {
1244     if (label == 0) {
1245       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1246     } else {
1247       multi_eps_labels_.Erase(label);
1248     }
1249   }
1250
1251   void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); }
1252
1253  private:
1254   M *matcher_;
1255   uint32 flags_;
1256   bool own_matcher_;  // Does this class delete the matcher?
1257
1258   // Multi-eps label set.
1259   CompactSet<Label, kNoLabel> multi_eps_labels_;
1260   typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1261
1262   bool current_loop_;  // Current arc is the implicit loop?
1263   mutable Arc loop_;   // For non-consuming symbols.
1264   bool done_;          // Matching done?
1265
1266   MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete;
1267 };
1268
1269 template <class M>
1270 inline bool MultiEpsMatcher<M>::Find(Label label) {
1271   multi_eps_iter_ = multi_eps_labels_.End();
1272   current_loop_ = false;
1273   bool ret;
1274   if (label == 0) {
1275     ret = matcher_->Find(0);
1276   } else if (label == kNoLabel) {
1277     if (flags_ & kMultiEpsList) {
1278       // Returns all non-consuming arcs (including epsilon).
1279       multi_eps_iter_ = multi_eps_labels_.Begin();
1280       while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1281              !matcher_->Find(*multi_eps_iter_)) {
1282         ++multi_eps_iter_;
1283       }
1284       if (multi_eps_iter_ != multi_eps_labels_.End()) {
1285         ret = true;
1286       } else {
1287         ret = matcher_->Find(kNoLabel);
1288       }
1289     } else {
1290       // Returns all epsilon arcs.
1291       ret = matcher_->Find(kNoLabel);
1292     }
1293   } else if ((flags_ & kMultiEpsLoop) &&
1294              multi_eps_labels_.Find(label) != multi_eps_labels_.End()) {
1295     // Returns implicit loop.
1296     current_loop_ = true;
1297     ret = true;
1298   } else {
1299     ret = matcher_->Find(label);
1300   }
1301   done_ = !ret;
1302   return ret;
1303 }
1304
1305 // This class discards any implicit matches (e.g., the implicit epsilon
1306 // self-loops in the SortedMatcher). Matchers are most often used in
1307 // composition/intersection where the implicit matches are needed
1308 // e.g. for epsilon processing. However, if a matcher is simply being
1309 // used to look-up explicit label matches, this class saves the user
1310 // from having to check for and discard the unwanted implicit matches
1311 // themselves.
1312 template <class M>
1313 class ExplicitMatcher : public MatcherBase<typename M::Arc> {
1314  public:
1315   using FST = typename M::FST;
1316   using Arc = typename FST::Arc;
1317   using Label = typename Arc::Label;
1318   using StateId = typename Arc::StateId;
1319   using Weight = typename Arc::Weight;
1320
1321   ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr)
1322       : matcher_(matcher ? matcher : new M(fst, match_type)),
1323         match_type_(match_type),
1324         error_(false) {}
1325
1326   ExplicitMatcher(const ExplicitMatcher<M> &matcher, bool safe = false)
1327       : matcher_(new M(*matcher.matcher_, safe)),
1328         match_type_(matcher.match_type_),
1329         error_(matcher.error_) {}
1330
1331   ExplicitMatcher<M> *Copy(bool safe = false) const override {
1332     return new ExplicitMatcher<M>(*this, safe);
1333   }
1334
1335   MatchType Type(bool test) const override { return matcher_->Type(test); }
1336
1337   void SetState(StateId s) final { matcher_->SetState(s); }
1338
1339   bool Find(Label label) final {
1340     matcher_->Find(label);
1341     CheckArc();
1342     return !Done();
1343   }
1344
1345   bool Done() const final { return matcher_->Done(); }
1346
1347   const Arc &Value() const final { return matcher_->Value(); }
1348
1349   void Next() final {
1350     matcher_->Next();
1351     CheckArc();
1352   }
1353
1354   Weight Final(StateId s) const final { return matcher_->Final(s); }
1355
1356   ssize_t Priority(StateId s) final { return matcher_->Priority(s); }
1357
1358   const FST &GetFst() const final { return matcher_->GetFst(); }
1359
1360   uint64 Properties(uint64 inprops) const override {
1361     return matcher_->Properties(inprops);
1362   }
1363
1364   const M *GetMatcher() const { return matcher_.get(); }
1365
1366   uint32 Flags() const override { return matcher_->Flags(); }
1367
1368  private:
1369   // Checks current arc if available and explicit. If not available, stops. If
1370   // not explicit, checks next ones.
1371   void CheckArc() {
1372     for (; !matcher_->Done(); matcher_->Next()) {
1373       const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel
1374                                                     : matcher_->Value().olabel;
1375       if (label != kNoLabel) return;
1376     }
1377   }
1378
1379   std::unique_ptr<M> matcher_;
1380   MatchType match_type_;  // Type of match requested.
1381   bool error_;            // Error encountered?
1382 };
1383
1384 // Generic matcher, templated on the FST definition.
1385 //
1386 // Here is a typical use:
1387 //
1388 //   Matcher<StdFst> matcher(fst, MATCH_INPUT);
1389 //   matcher.SetState(state);
1390 //   if (matcher.Find(label))
1391 //     for (; !matcher.Done(); matcher.Next()) {
1392 //       auto &arc = matcher.Value();
1393 //       ...
1394 //     }
1395 template <class F>
1396 class Matcher {
1397  public:
1398   using FST = F;
1399   using Arc = typename F::Arc;
1400   using Label = typename Arc::Label;
1401   using StateId = typename Arc::StateId;
1402   using Weight = typename Arc::Weight;
1403
1404   Matcher(const FST &fst, MatchType match_type) {
1405     base_.reset(fst.InitMatcher(match_type));
1406     if (!base_) base_.reset(new SortedMatcher<FST>(fst, match_type));
1407   }
1408
1409   Matcher(const Matcher<FST> &matcher, bool safe = false) {
1410     base_.reset(matcher.base_->Copy(safe));
1411   }
1412
1413   // Takes ownership of the provided matcher.
1414   explicit Matcher(MatcherBase<Arc> *base_matcher) {
1415     base_.reset(base_matcher);
1416   }
1417
1418   Matcher<FST> *Copy(bool safe = false) const {
1419     return new Matcher<FST>(*this, safe);
1420   }
1421
1422   MatchType Type(bool test) const { return base_->Type(test); }
1423
1424   void SetState(StateId s) { base_->SetState(s); }
1425
1426   bool Find(Label label) { return base_->Find(label); }
1427
1428   bool Done() const { return base_->Done(); }
1429
1430   const Arc &Value() const { return base_->Value(); }
1431
1432   void Next() { base_->Next(); }
1433
1434   const FST &GetFst() const {
1435     return static_cast<const FST &>(base_->GetFst());
1436   }
1437
1438   uint64 Properties(uint64 props) const { return base_->Properties(props); }
1439
1440   Weight Final(StateId s) const { return base_->Final(s); }
1441
1442   uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1443
1444   ssize_t Priority(StateId s) { return base_->Priority(s); }
1445
1446  private:
1447   std::unique_ptr<MatcherBase<Arc>> base_;
1448 };
1449
1450 }  // namespace fst
1451
1452 #endif  // FST_LIB_MATCHER_H_