Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / matcher.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to allow matching labels leaving FST states.
5
6 #ifndef FST_LIB_MATCHER_H_
7 #define FST_LIB_MATCHER_H_
8
9 #include <algorithm>
10 #include <unordered_map>
11 #include <utility>
12
13 #include <fst/log.h>
14
15 #include <fst/mutable-fst.h>  // for all internal FST accessors.
16
17
18 namespace fst {
19
20 // Matchers find and iterate through requested labels at FST states. In the
21 // simplest form, these are just some associative map or search keyed on labels.
22 // More generally, they may implement matching special labels that represent
23 // sets of labels such as sigma (all), rho (rest), or phi (fail). The Matcher
24 // interface is:
25 //
26 // template <class F>
27 // class Matcher {
28 //  public:
29 //   using FST = F;
30 //   using Arc = typename FST::Arc;
31 //   using Label = typename Arc::Label;
32 //   using StateId = typename Arc::StateId;
33 //   using Weight = typename Arc::Weight;
34 //
35 //   // Required constructors.
36 //
37 //   Matcher(const FST &fst, MatchType type);
38 //   Matcher(const Matcher &matcher, bool safe = false);
39 //
40 //   // Standard copy method.
41 //   Matcher<FST> *Copy(bool safe = false) const override;
42 //
43 //   // Returns the match type that can be provided (depending on compatibility
44 //   of the input FST). It is either the requested match type, MATCH_NONE, or
45 //   MATCH_UNKNOWN. If test is false, a costly testing is avoided, but
46 //   MATCH_UNKNOWN may be returned. If test is true, a definite answer is
47 //   returned, but may involve more costly computation (e.g., visiting the FST).
48 //   MatchType Type(bool test) const override;
49 //
50 //   // Specifies the current state.
51 //   void SetState(StateId s) final;
52 //
53 //   // Finds matches to a label at the current state, returning true if a match
54 //   // found. kNoLabel matches any non-consuming transitions, e.g., epsilon
55 //   // transitions, which do not require a matching symbol.
56 //   bool Find(Label label) final;
57 //
58 //   // Iterator methods. Note that initially and after SetState() these have
59 //   undefined behavior until Find() is called.
60 //
61 //   bool Done() const final;
62 //
63 //   const Arc &Value() const final;
64 //
65 //   void Next() final;
66 //
67 //   // Returns final weight of a state.
68 //   Weight Final(StateId) const final;
69 //
70 //   // Indicates preference for being the side used for matching in
71 //   // composition. If the value is kRequirePriority, then it is
72 //   // mandatory that it be used. Calling this method without passing the
73 //   // current state of the matcher invalidates the state of the matcher.
74 //   ssize_t Priority(StateId s) final;
75 //
76 //   // This specifies the known FST properties as viewed from this matcher. It
77 //   // takes as argument the input FST's known properties.
78 //   uint64 Properties(uint64 props) const override;
79 //
80 //   // Returns matcher flags.
81 //   uint32 Flags() const override;
82 //
83 //   // Returns matcher FST.
84 //   const FST &GetFst() const override;
85 // };
86
87 // Basic matcher flags.
88
89 // Matcher needs to be used as the matching side in composition for
90 // at least one state (has kRequirePriority).
91 constexpr uint32 kRequireMatch = 0x00000001;
92
93 // Flags used for basic matchers (see also lookahead.h).
94 constexpr uint32 kMatcherFlags = kRequireMatch;
95
96 // Matcher priority that is mandatory.
97 constexpr ssize_t kRequirePriority = -1;
98
99 // Matcher interface, templated on the Arc definition; used for matcher
100 // specializations that are returned by the InitMatcher FST method.
101 template <class A>
102 class MatcherBase {
103  public:
104   using Arc = A;
105   using Label = typename Arc::Label;
106   using StateId = typename Arc::StateId;
107   using Weight = typename Arc::Weight;
108
109   virtual ~MatcherBase() {}
110
111   // Virtual interface.
112
113   virtual MatcherBase<Arc> *Copy(bool safe = false) const = 0;
114   virtual MatchType Type(bool) const = 0;
115   virtual void SetState(StateId) = 0;
116   virtual bool Find(Label) = 0;
117   virtual bool Done() const = 0;
118   virtual const Arc &Value() const = 0;
119   virtual void Next() = 0;
120   virtual const Fst<Arc> &GetFst() const = 0;
121   virtual uint64 Properties(uint64) const = 0;
122
123   // Trivial implementations that can be used by derived classes. Full
124   // devirtualization is expected for any derived class marked final.
125   virtual uint32 Flags() const { return 0; }
126
127   virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); }
128
129   virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); }
130 };
131
132 // A matcher that expects sorted labels on the side to be matched.
133 // If match_type == MATCH_INPUT, epsilons match the implicit self-loop
134 // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
135 // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
136 // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
137 template <class F>
138 class SortedMatcher : public MatcherBase<typename F::Arc> {
139  public:
140   using FST = F;
141   using Arc = typename FST::Arc;
142   using Label = typename Arc::Label;
143   using StateId = typename Arc::StateId;
144   using Weight = typename Arc::Weight;
145
146   using MatcherBase<Arc>::Flags;
147   using MatcherBase<Arc>::Properties;
148
149   // Labels >= binary_label will be searched for by binary search;
150   // o.w. linear search is used.
151   SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1)
152       : fst_(fst.Copy()),
153         state_(kNoStateId),
154         aiter_(nullptr),
155         match_type_(match_type),
156         binary_label_(binary_label),
157         match_label_(kNoLabel),
158         narcs_(0),
159         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
160         error_(false),
161         aiter_pool_(1) {
162     switch (match_type_) {
163       case MATCH_INPUT:
164       case MATCH_NONE:
165         break;
166       case MATCH_OUTPUT:
167         std::swap(loop_.ilabel, loop_.olabel);
168         break;
169       default:
170         FSTERROR() << "SortedMatcher: Bad match type";
171         match_type_ = MATCH_NONE;
172         error_ = true;
173     }
174   }
175
176   SortedMatcher(const SortedMatcher<FST> &matcher, bool safe = false)
177       : fst_(matcher.fst_->Copy(safe)),
178         state_(kNoStateId),
179         aiter_(nullptr),
180         match_type_(matcher.match_type_),
181         binary_label_(matcher.binary_label_),
182         match_label_(kNoLabel),
183         narcs_(0),
184         loop_(matcher.loop_),
185         error_(matcher.error_),
186         aiter_pool_(1) {}
187
188   ~SortedMatcher() override { Destroy(aiter_, &aiter_pool_); }
189
190   SortedMatcher<FST> *Copy(bool safe = false) const override {
191     return new SortedMatcher<FST>(*this, safe);
192   }
193
194   MatchType Type(bool test) const override {
195     if (match_type_ == MATCH_NONE) return match_type_;
196     const auto true_prop =
197         match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
198     const auto false_prop =
199         match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
200     const auto props = fst_->Properties(true_prop | false_prop, test);
201     if (props & true_prop) {
202       return match_type_;
203     } else if (props & false_prop) {
204       return MATCH_NONE;
205     } else {
206       return MATCH_UNKNOWN;
207     }
208   }
209
210   void SetState(StateId s) final {
211     if (state_ == s) return;
212     state_ = s;
213     if (match_type_ == MATCH_NONE) {
214       FSTERROR() << "SortedMatcher: Bad match type";
215       error_ = true;
216     }
217     Destroy(aiter_, &aiter_pool_);
218     aiter_ = new (&aiter_pool_) ArcIterator<FST>(*fst_, s);
219     aiter_->SetFlags(kArcNoCache, kArcNoCache);
220     narcs_ = internal::NumArcs(*fst_, s);
221     loop_.nextstate = s;
222   }
223
224   bool Find(Label match_label) final {
225     exact_match_ = true;
226     if (error_) {
227       current_loop_ = false;
228       match_label_ = kNoLabel;
229       return false;
230     }
231     current_loop_ = match_label == 0;
232     match_label_ = match_label == kNoLabel ? 0 : match_label;
233     if (Search()) {
234       return true;
235     } else {
236       return current_loop_;
237     }
238   }
239
240   // Positions matcher to the first position where inserting match_label would
241   // maintain the sort order.
242   void LowerBound(Label label) {
243     exact_match_ = false;
244     current_loop_ = false;
245     if (error_) {
246       match_label_ = kNoLabel;
247       return;
248     }
249     match_label_ = label;
250     Search();
251   }
252
253   // After Find(), returns false if no more exact matches.
254   // After LowerBound(), returns false if no more arcs.
255   bool Done() const final {
256     if (current_loop_) return false;
257     if (aiter_->Done()) return true;
258     if (!exact_match_) return false;
259     aiter_->SetFlags(
260         match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
261         kArcValueFlags);
262     const auto label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel
263                                                   : aiter_->Value().olabel;
264     return label != match_label_;
265   }
266
267   const Arc &Value() const final {
268     if (current_loop_) return loop_;
269     aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
270     return aiter_->Value();
271   }
272
273   void Next() final {
274     if (current_loop_) {
275       current_loop_ = false;
276     } else {
277       aiter_->Next();
278     }
279   }
280
281   Weight Final(StateId s) const final {
282     return MatcherBase<Arc>::Final(s);
283   }
284
285   ssize_t Priority(StateId s) final {
286     return MatcherBase<Arc>::Priority(s);
287   }
288
289   const FST &GetFst() const override { return *fst_; }
290
291   uint64 Properties(uint64 inprops) const override {
292     return inprops | (error_ ? kError : 0);
293   }
294
295   size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
296
297  private:
298   Label GetLabel() const {
299     const auto &arc = aiter_->Value();
300     return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
301   }
302
303   bool BinarySearch();
304   bool LinearSearch();
305   bool Search();
306
307   std::unique_ptr<const FST> fst_;
308   StateId state_;            // Matcher state.
309   ArcIterator<FST> *aiter_;  // Iterator for current state.
310   MatchType match_type_;     // Type of match to perform.
311   Label binary_label_;       // Least label for binary search.
312   Label match_label_;        // Current label to be matched.
313   size_t narcs_;             // Current state arc count.
314   Arc loop_;                 // For non-consuming symbols.
315   bool current_loop_;        // Current arc is the implicit loop.
316   bool exact_match_;         // Exact match or lower bound?
317   bool error_;               // Error encountered?
318   MemoryPool<ArcIterator<FST>> aiter_pool_;  // Pool of arc iterators.
319 };
320
321 // Returns true iff match to match_label_, positioning arc iterator at lower
322 // bound.
323 template <class FST>
324 inline bool SortedMatcher<FST>::BinarySearch() {
325   size_t low = 0;
326   size_t high = narcs_;
327   while (low < high) {
328     const size_t mid = (low + high) / 2;
329     aiter_->Seek(mid);
330     const auto label = GetLabel();
331     if (label > match_label_) {
332       high = mid;
333     } else if (label < match_label_) {
334       low = mid + 1;
335     } else {
336       // Otherwise, search backwards for the first match.
337       for (size_t i = mid; i > low; --i) {
338         aiter_->Seek(i - 1);
339         const auto label = GetLabel();
340         if (label != match_label_) {
341           aiter_->Seek(i);
342           return true;
343         }
344       }
345       return true;
346     }
347   }
348   aiter_->Seek(low);
349   return false;
350 }
351
352 // Returns true iff match to match_label_, positioning arc iterator at lower
353 // bound.
354 template <class FST>
355 inline bool SortedMatcher<FST>::LinearSearch() {
356   for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
357     const auto label = GetLabel();
358     if (label == match_label_) return true;
359     if (label > match_label_) break;
360   }
361   return false;
362 }
363
364 // Returns true iff match to match_label_, positioning arc iterator at lower
365 // bound.
366 template <class FST>
367 inline bool SortedMatcher<FST>::Search() {
368   aiter_->SetFlags(match_type_ == MATCH_INPUT ?
369                    kArcILabelValue : kArcOLabelValue,
370                    kArcValueFlags);
371   if (match_label_ >= binary_label_) {
372     return BinarySearch();
373   } else {
374     return LinearSearch();
375   }
376 }
377
378 // A matcher that stores labels in a per-state hash table populated upon the
379 // first visit to that state. Sorting is not required. Treatment of
380 // epsilons are the same as with SortedMatcher.
381 template <class F>
382 class HashMatcher : public MatcherBase<typename F::Arc> {
383  public:
384   using FST = F;
385   using Arc = typename FST::Arc;
386   using Label = typename Arc::Label;
387   using StateId = typename Arc::StateId;
388   using Weight = typename Arc::Weight;
389
390   using MatcherBase<Arc>::Flags;
391   using MatcherBase<Arc>::Final;
392   using MatcherBase<Arc>::Priority;
393
394   HashMatcher(const FST &fst, MatchType match_type)
395       : fst_(fst.Copy()),
396         state_(kNoStateId),
397         match_type_(match_type),
398         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
399     switch (match_type_) {
400       case MATCH_INPUT:
401       case MATCH_NONE:
402         break;
403       case MATCH_OUTPUT:
404         std::swap(loop_.ilabel, loop_.olabel);
405         break;
406       default:
407         FSTERROR() << "SortedMatcher: Bad match type";
408         match_type_ = MATCH_NONE;
409         error_ = true;
410     }
411   }
412
413   HashMatcher(const HashMatcher<FST> &matcher, bool safe = false)
414       : fst_(matcher.fst_->Copy(safe)),
415         state_(kNoStateId),
416         match_type_(matcher.match_type_),
417         loop_(matcher.loop_),
418         error_(matcher.error_) {}
419
420   HashMatcher<FST> *Copy(bool safe = false) const override {
421     return new HashMatcher<FST>(*this, safe);
422   }
423
424   // The argument is ignored as there are no relevant properties to test.
425   MatchType Type(bool test) const override { return match_type_; }
426
427   void SetState(StateId s) final;
428
429   bool Find(Label label) final {
430     current_loop_ = label == 0;
431     if (label == 0) {
432       Search(label);
433       return true;
434     }
435     if (label == kNoLabel) label = 0;
436     return Search(label);
437   }
438
439   bool Done() const final {
440     if (current_loop_) return false;
441     return label_it_ == label_end_;
442   }
443
444   const Arc &Value() const final {
445     if (current_loop_) return loop_;
446     aiter_->Seek(label_it_->second);
447     return aiter_->Value();
448   }
449
450   void Next() final {
451     if (current_loop_) {
452       current_loop_ = false;
453     } else {
454       ++label_it_;
455     }
456   }
457
458   const FST &GetFst() const override { return *fst_; }
459
460   uint64 Properties(uint64 inprops) const override {
461     return inprops | (error_ ? kError : 0);
462   }
463
464  private:
465   bool Search(Label match_label);
466
467   using LabelTable = std::unordered_multimap<Label, size_t>;
468   using StateTable = std::unordered_map<StateId, LabelTable>;
469
470   std::unique_ptr<const FST> fst_;
471   StateId state_;  // Matcher state.
472   MatchType match_type_;
473   Arc loop_;           // The implicit loop itself.
474   bool current_loop_;  // Is the current arc is the implicit loop?
475   bool error_;         // Error encountered?
476   std::unique_ptr<ArcIterator<FST>> aiter_;
477   StateTable state_table_;   // Table from states to label table.
478   LabelTable *label_table_;  // Pointer to current state's label table.
479   typename LabelTable::iterator label_it_;   // Position for label.
480   typename LabelTable::iterator label_end_;  // Position for last label + 1.
481 };
482
483 template <class FST>
484 void HashMatcher<FST>::SetState(typename FST::Arc::StateId s) {
485   if (state_ == s) return;
486   // Resets everything for the state.
487   state_ = s;
488   loop_.nextstate = state_;
489   aiter_.reset(new ArcIterator<FST>(*fst_, state_));
490   if (match_type_ == MATCH_NONE) {
491     FSTERROR() << "HashMatcher: Bad match type";
492     error_ = false;
493   }
494   // Attempts to insert a new label table; if it already exists,
495   // no additional work is done and we simply return.
496   auto it_and_success = state_table_.emplace(state_, LabelTable());
497   if (!it_and_success.second) return;
498   // Otherwise, populate this new table.
499   // Sets instance's pointer to the label table for this state.
500   label_table_ = &(it_and_success.first->second);
501   // Populates the label table.
502   label_table_->reserve(internal::NumArcs(*fst_, state_));
503   const auto aiter_flags =
504       (match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue) |
505       kArcNoCache;
506   aiter_->SetFlags(aiter_flags, kArcFlags);
507   for (; !aiter_->Done(); aiter_->Next()) {
508     const auto label = (match_type_ == MATCH_INPUT) ? aiter_->Value().ilabel
509                                                     : aiter_->Value().olabel;
510     label_table_->emplace(label, aiter_->Position());
511   }
512   aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
513 }
514
515 template <class FST>
516 inline bool HashMatcher<FST>::Search(typename FST::Arc::Label match_label) {
517   auto range = label_table_->equal_range(match_label);
518   if (range.first == range.second) return false;
519   label_it_ = range.first;
520   label_end_ = range.second;
521   aiter_->Seek(label_it_->second);
522   return true;
523 }
524
525 // Specifies whether we rewrite both the input and output sides during matching.
526 enum MatcherRewriteMode {
527   MATCHER_REWRITE_AUTO = 0,  // Rewrites both sides iff acceptor.
528   MATCHER_REWRITE_ALWAYS,
529   MATCHER_REWRITE_NEVER
530 };
531
532 // For any requested label that doesn't match at a state, this matcher
533 // considers the *unique* transition that matches the label 'phi_label'
534 // (phi = 'fail'), and recursively looks for a match at its
535 // destination.  When 'phi_loop' is true, if no match is found but a
536 // phi self-loop is found, then the phi transition found is returned
537 // with the phi_label rewritten as the requested label (both sides if
538 // an acceptor, or if 'rewrite_both' is true and both input and output
539 // labels of the found transition are 'phi_label').  If 'phi_label' is
540 // kNoLabel, this special matching is not done.  PhiMatcher is
541 // templated itself on a matcher, which is used to perform the
542 // underlying matching.  By default, the underlying matcher is
543 // constructed by PhiMatcher. The user can instead pass in this
544 // object; in that case, PhiMatcher takes its ownership.
545 // Phi non-determinism not supported. No non-consuming symbols other
546 // than epsilon supported with the underlying template argument matcher.
547 template <class M>
548 class PhiMatcher : public MatcherBase<typename M::Arc> {
549  public:
550   using FST = typename M::FST;
551   using Arc = typename FST::Arc;
552   using Label = typename Arc::Label;
553   using StateId = typename Arc::StateId;
554   using Weight = typename Arc::Weight;
555
556   PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel,
557              bool phi_loop = true,
558              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
559              M *matcher = nullptr)
560       : matcher_(matcher ? matcher : new M(fst, match_type)),
561         match_type_(match_type),
562         phi_label_(phi_label),
563         state_(kNoStateId),
564         phi_loop_(phi_loop),
565         error_(false) {
566     if (match_type == MATCH_BOTH) {
567       FSTERROR() << "PhiMatcher: Bad match type";
568       match_type_ = MATCH_NONE;
569       error_ = true;
570     }
571     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
572       rewrite_both_ = fst.Properties(kAcceptor, true);
573     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
574       rewrite_both_ = true;
575     } else {
576       rewrite_both_ = false;
577     }
578   }
579
580   PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
581       : matcher_(new M(*matcher.matcher_, safe)),
582         match_type_(matcher.match_type_),
583         phi_label_(matcher.phi_label_),
584         rewrite_both_(matcher.rewrite_both_),
585         state_(kNoStateId),
586         phi_loop_(matcher.phi_loop_),
587         error_(matcher.error_) {}
588
589   PhiMatcher<M> *Copy(bool safe = false) const override {
590     return new PhiMatcher<M>(*this, safe);
591   }
592
593   MatchType Type(bool test) const override { return matcher_->Type(test); }
594
595   void SetState(StateId s) final {
596     if (state_ == s) return;
597     matcher_->SetState(s);
598     state_ = s;
599     has_phi_ = phi_label_ != kNoLabel;
600   }
601
602   bool Find(Label match_label) final;
603
604   bool Done() const final { return matcher_->Done(); }
605
606   const Arc &Value() const final {
607     if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
608       return matcher_->Value();
609     } else if (phi_match_ == 0) {  // Virtual epsilon loop.
610       phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
611       if (match_type_ == MATCH_OUTPUT) {
612         std::swap(phi_arc_.ilabel, phi_arc_.olabel);
613       }
614       return phi_arc_;
615     } else {
616       phi_arc_ = matcher_->Value();
617       phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
618       if (phi_match_ != kNoLabel) {  // Phi loop match.
619         if (rewrite_both_) {
620           if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_;
621           if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_;
622         } else if (match_type_ == MATCH_INPUT) {
623           phi_arc_.ilabel = phi_match_;
624         } else {
625           phi_arc_.olabel = phi_match_;
626         }
627       }
628       return phi_arc_;
629     }
630   }
631
632   void Next() final { matcher_->Next(); }
633
634   Weight Final(StateId s) const final {
635     auto weight = matcher_->Final(s);
636     if (phi_label_ == kNoLabel || weight != Weight::Zero()) {
637       return weight;
638     }
639     weight = Weight::One();
640     matcher_->SetState(s);
641     while (matcher_->Final(s) == Weight::Zero()) {
642       if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) break;
643       weight = Times(weight, matcher_->Value().weight);
644       if (s == matcher_->Value().nextstate) {
645         return Weight::Zero();  // Does not follow phi self-loops.
646       }
647       s = matcher_->Value().nextstate;
648       matcher_->SetState(s);
649     }
650     weight = Times(weight, matcher_->Final(s));
651     return weight;
652   }
653
654   ssize_t Priority(StateId s) final {
655     if (phi_label_ != kNoLabel) {
656       matcher_->SetState(s);
657       const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_);
658       return has_phi ? kRequirePriority : matcher_->Priority(s);
659     } else {
660       return matcher_->Priority(s);
661     }
662   }
663
664   const FST &GetFst() const override { return matcher_->GetFst(); }
665
666   uint64 Properties(uint64 props) const override;
667
668   uint32 Flags() const override {
669     if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) {
670       return matcher_->Flags();
671     }
672     return matcher_->Flags() | kRequireMatch;
673   }
674
675   Label PhiLabel() const { return phi_label_; }
676
677  private:
678   mutable std::unique_ptr<M> matcher_;
679   MatchType match_type_;  // Type of match requested.
680   Label phi_label_;       // Label that represents the phi transition.
681   bool rewrite_both_;     // Rewrite both sides when both are phi_label_?
682   bool has_phi_;          // Are there possibly phis at the current state?
683   Label phi_match_;       // Current label that matches phi loop.
684   mutable Arc phi_arc_;   // Arc to return.
685   StateId state_;         // Matcher state.
686   Weight phi_weight_;     // Product of the weights of phi transitions taken.
687   bool phi_loop_;         // When true, phi self-loop are allowed and treated
688                           // as rho (required for Aho-Corasick).
689   bool error_;            // Error encountered?
690
691   PhiMatcher &operator=(const PhiMatcher &) = delete;
692 };
693
694 template <class M>
695 inline bool PhiMatcher<M>::Find(Label label) {
696   if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
697     FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
698     error_ = true;
699     return false;
700   }
701   matcher_->SetState(state_);
702   phi_match_ = kNoLabel;
703   phi_weight_ = Weight::One();
704   // If phi_label_ == 0, there are no more true epsilon arcs.
705   if (phi_label_ == 0) {
706     if (label == kNoLabel) {
707       return false;
708     }
709     if (label == 0) {  // but a virtual epsilon loop needs to be returned.
710       if (!matcher_->Find(kNoLabel)) {
711         return matcher_->Find(0);
712       } else {
713         phi_match_ = 0;
714         return true;
715       }
716     }
717   }
718   if (!has_phi_ || label == 0 || label == kNoLabel) {
719     return matcher_->Find(label);
720   }
721   auto s = state_;
722   while (!matcher_->Find(label)) {
723     // Look for phi transition (if phi_label_ == 0, we need to look
724     // for -1 to avoid getting the virtual self-loop)
725     if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false;
726     if (phi_loop_ && matcher_->Value().nextstate == s) {
727       phi_match_ = label;
728       return true;
729     }
730     phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
731     s = matcher_->Value().nextstate;
732     matcher_->Next();
733     if (!matcher_->Done()) {
734       FSTERROR() << "PhiMatcher: Phi non-determinism not supported";
735       error_ = true;
736     }
737     matcher_->SetState(s);
738   }
739   return true;
740 }
741
742 template <class M>
743 inline uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
744   auto outprops = matcher_->Properties(inprops);
745   if (error_) outprops |= kError;
746   if (match_type_ == MATCH_NONE) {
747     return outprops;
748   } else if (match_type_ == MATCH_INPUT) {
749     if (phi_label_ == 0) {
750       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
751       outprops |= kNoEpsilons | kNoIEpsilons;
752     }
753     if (rewrite_both_) {
754       return outprops &
755              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
756                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
757     } else {
758       return outprops &
759              ~(kODeterministic | kAcceptor | kString | kILabelSorted |
760                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
761     }
762   } else if (match_type_ == MATCH_OUTPUT) {
763     if (phi_label_ == 0) {
764       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
765       outprops |= kNoEpsilons | kNoOEpsilons;
766     }
767     if (rewrite_both_) {
768       return outprops &
769              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
770                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
771     } else {
772       return outprops &
773              ~(kIDeterministic | kAcceptor | kString | kILabelSorted |
774                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
775     }
776   } else {
777     // Shouldn't ever get here.
778     FSTERROR() << "PhiMatcher: Bad match type: " << match_type_;
779     return 0;
780   }
781 }
782
783 // For any requested label that doesn't match at a state, this matcher
784 // considers all transitions that match the label 'rho_label' (rho =
785 // 'rest').  Each such rho transition found is returned with the
786 // rho_label rewritten as the requested label (both sides if an
787 // acceptor, or if 'rewrite_both' is true and both input and output
788 // labels of the found transition are 'rho_label').  If 'rho_label' is
789 // kNoLabel, this special matching is not done.  RhoMatcher is
790 // templated itself on a matcher, which is used to perform the
791 // underlying matching.  By default, the underlying matcher is
792 // constructed by RhoMatcher.  The user can instead pass in this
793 // object; in that case, RhoMatcher takes its ownership.
794 // No non-consuming symbols other than epsilon supported with
795 // the underlying template argument matcher.
796 template <class M>
797 class RhoMatcher : public MatcherBase<typename M::Arc> {
798  public:
799   using FST = typename M::FST;
800   using Arc = typename FST::Arc;
801   using Label = typename Arc::Label;
802   using StateId = typename Arc::StateId;
803   using Weight = typename Arc::Weight;
804
805   RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel,
806              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
807              M *matcher = nullptr)
808       : matcher_(matcher ? matcher : new M(fst, match_type)),
809         match_type_(match_type),
810         rho_label_(rho_label),
811         error_(false),
812         state_(kNoStateId) {
813     if (match_type == MATCH_BOTH) {
814       FSTERROR() << "RhoMatcher: Bad match type";
815       match_type_ = MATCH_NONE;
816       error_ = true;
817     }
818     if (rho_label == 0) {
819       FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
820       rho_label_ = kNoLabel;
821       error_ = true;
822     }
823     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
824       rewrite_both_ = fst.Properties(kAcceptor, true);
825     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
826       rewrite_both_ = true;
827     } else {
828       rewrite_both_ = false;
829     }
830   }
831
832   RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
833       : matcher_(new M(*matcher.matcher_, safe)),
834         match_type_(matcher.match_type_),
835         rho_label_(matcher.rho_label_),
836         rewrite_both_(matcher.rewrite_both_),
837         error_(matcher.error_),
838         state_(kNoStateId) {}
839
840   RhoMatcher<M> *Copy(bool safe = false) const override {
841     return new RhoMatcher<M>(*this, safe);
842   }
843
844   MatchType Type(bool test) const override { return matcher_->Type(test); }
845
846   void SetState(StateId s) final {
847     if (state_ == s) return;
848     state_ = s;
849     matcher_->SetState(s);
850     has_rho_ = rho_label_ != kNoLabel;
851   }
852
853   bool Find(Label label) final {
854     if (label == rho_label_ && rho_label_ != kNoLabel) {
855       FSTERROR() << "RhoMatcher::Find: bad label (rho)";
856       error_ = true;
857       return false;
858     }
859     if (matcher_->Find(label)) {
860       rho_match_ = kNoLabel;
861       return true;
862     } else if (has_rho_ && label != 0 && label != kNoLabel &&
863                (has_rho_ = matcher_->Find(rho_label_))) {
864       rho_match_ = label;
865       return true;
866     } else {
867       return false;
868     }
869   }
870
871   bool Done() const final { return matcher_->Done(); }
872
873   const Arc &Value() const final {
874     if (rho_match_ == kNoLabel) {
875       return matcher_->Value();
876     } else {
877       rho_arc_ = matcher_->Value();
878       if (rewrite_both_) {
879         if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_;
880         if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_;
881       } else if (match_type_ == MATCH_INPUT) {
882         rho_arc_.ilabel = rho_match_;
883       } else {
884         rho_arc_.olabel = rho_match_;
885       }
886       return rho_arc_;
887     }
888   }
889
890   void Next() final { matcher_->Next(); }
891
892   Weight Final(StateId s) const final { return matcher_->Final(s); }
893
894   ssize_t Priority(StateId s) final {
895     state_ = s;
896     matcher_->SetState(s);
897     has_rho_ = matcher_->Find(rho_label_);
898     if (has_rho_) {
899       return kRequirePriority;
900     } else {
901       return matcher_->Priority(s);
902     }
903   }
904
905   const FST &GetFst() const override { return matcher_->GetFst(); }
906
907   uint64 Properties(uint64 props) const override;
908
909   uint32 Flags() const override {
910     if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) {
911       return matcher_->Flags();
912     }
913     return matcher_->Flags() | kRequireMatch;
914   }
915
916   Label RhoLabel() const { return rho_label_; }
917
918  private:
919   std::unique_ptr<M> matcher_;
920   MatchType match_type_;  // Type of match requested.
921   Label rho_label_;       // Label that represents the rho transition
922   bool rewrite_both_;     // Rewrite both sides when both are rho_label_?
923   bool has_rho_;          // Are there possibly rhos at the current state?
924   Label rho_match_;       // Current label that matches rho transition.
925   mutable Arc rho_arc_;   // Arc to return when rho match.
926   bool error_;            // Error encountered?
927   StateId state_;         // Matcher state.
928 };
929
930 template <class M>
931 inline uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
932   auto outprops = matcher_->Properties(inprops);
933   if (error_) outprops |= kError;
934   if (match_type_ == MATCH_NONE) {
935     return outprops;
936   } else if (match_type_ == MATCH_INPUT) {
937     if (rewrite_both_) {
938       return outprops &
939              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
940                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
941     } else {
942       return outprops &
943              ~(kODeterministic | kAcceptor | kString | kILabelSorted |
944                kNotILabelSorted);
945     }
946   } else if (match_type_ == MATCH_OUTPUT) {
947     if (rewrite_both_) {
948       return outprops &
949              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
950                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
951     } else {
952       return outprops &
953              ~(kIDeterministic | kAcceptor | kString | kOLabelSorted |
954                kNotOLabelSorted);
955     }
956   } else {
957     // Shouldn't ever get here.
958     FSTERROR() << "RhoMatcher: Bad match type: " << match_type_;
959     return 0;
960   }
961 }
962
963 // For any requested label, this matcher considers all transitions
964 // that match the label 'sigma_label' (sigma = "any"), and this in
965 // additions to transitions with the requested label.  Each such sigma
966 // transition found is returned with the sigma_label rewritten as the
967 // requested label (both sides if an acceptor, or if 'rewrite_both' is
968 // true and both input and output labels of the found transition are
969 // 'sigma_label').  If 'sigma_label' is kNoLabel, this special
970 // matching is not done.  SigmaMatcher is templated itself on a
971 // matcher, which is used to perform the underlying matching.  By
972 // default, the underlying matcher is constructed by SigmaMatcher.
973 // The user can instead pass in this object; in that case,
974 // SigmaMatcher takes its ownership.  No non-consuming symbols other
975 // than epsilon supported with the underlying template argument matcher.
976 template <class M>
977 class SigmaMatcher : public MatcherBase<typename M::Arc> {
978  public:
979   using FST = typename M::FST;
980   using Arc = typename FST::Arc;
981   using Label = typename Arc::Label;
982   using StateId = typename Arc::StateId;
983   using Weight = typename Arc::Weight;
984
985   SigmaMatcher(const FST &fst, MatchType match_type,
986                Label sigma_label = kNoLabel,
987                MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
988                M *matcher = nullptr)
989       : matcher_(matcher ? matcher : new M(fst, match_type)),
990         match_type_(match_type),
991         sigma_label_(sigma_label),
992         error_(false),
993         state_(kNoStateId) {
994     if (match_type == MATCH_BOTH) {
995       FSTERROR() << "SigmaMatcher: Bad match type";
996       match_type_ = MATCH_NONE;
997       error_ = true;
998     }
999     if (sigma_label == 0) {
1000       FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
1001       sigma_label_ = kNoLabel;
1002       error_ = true;
1003     }
1004     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
1005       rewrite_both_ = fst.Properties(kAcceptor, true);
1006     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
1007       rewrite_both_ = true;
1008     } else {
1009       rewrite_both_ = false;
1010     }
1011   }
1012
1013   SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
1014       : matcher_(new M(*matcher.matcher_, safe)),
1015         match_type_(matcher.match_type_),
1016         sigma_label_(matcher.sigma_label_),
1017         rewrite_both_(matcher.rewrite_both_),
1018         error_(matcher.error_),
1019         state_(kNoStateId) {}
1020
1021   SigmaMatcher<M> *Copy(bool safe = false) const override {
1022     return new SigmaMatcher<M>(*this, safe);
1023   }
1024
1025   MatchType Type(bool test) const override { return matcher_->Type(test); }
1026
1027   void SetState(StateId s) final {
1028     if (state_ == s) return;
1029     state_ = s;
1030     matcher_->SetState(s);
1031     has_sigma_ =
1032         (sigma_label_ != kNoLabel) ? matcher_->Find(sigma_label_) : false;
1033   }
1034
1035   bool Find(Label match_label) final {
1036     match_label_ = match_label;
1037     if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
1038       FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
1039       error_ = true;
1040       return false;
1041     }
1042     if (matcher_->Find(match_label)) {
1043       sigma_match_ = kNoLabel;
1044       return true;
1045     } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
1046                matcher_->Find(sigma_label_)) {
1047       sigma_match_ = match_label;
1048       return true;
1049     } else {
1050       return false;
1051     }
1052   }
1053
1054   bool Done() const final { return matcher_->Done(); }
1055
1056   const Arc &Value() const final {
1057     if (sigma_match_ == kNoLabel) {
1058       return matcher_->Value();
1059     } else {
1060       sigma_arc_ = matcher_->Value();
1061       if (rewrite_both_) {
1062         if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_;
1063         if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_;
1064       } else if (match_type_ == MATCH_INPUT) {
1065         sigma_arc_.ilabel = sigma_match_;
1066       } else {
1067         sigma_arc_.olabel = sigma_match_;
1068       }
1069       return sigma_arc_;
1070     }
1071   }
1072
1073   void Next() final {
1074     matcher_->Next();
1075     if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
1076         (match_label_ > 0)) {
1077       matcher_->Find(sigma_label_);
1078       sigma_match_ = match_label_;
1079     }
1080   }
1081
1082   Weight Final(StateId s) const final { return matcher_->Final(s); }
1083
1084   ssize_t Priority(StateId s) final {
1085     if (sigma_label_ != kNoLabel) {
1086       SetState(s);
1087       return has_sigma_ ? kRequirePriority : matcher_->Priority(s);
1088     } else {
1089       return matcher_->Priority(s);
1090     }
1091   }
1092
1093   const FST &GetFst() const override { return matcher_->GetFst(); }
1094
1095   uint64 Properties(uint64 props) const override;
1096
1097   uint32 Flags() const override {
1098     if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) {
1099       return matcher_->Flags();
1100     }
1101     return matcher_->Flags() | kRequireMatch;
1102   }
1103
1104   Label SigmaLabel() const { return sigma_label_; }
1105
1106  private:
1107   std::unique_ptr<M> matcher_;
1108   MatchType match_type_;   // Type of match requested.
1109   Label sigma_label_;      // Label that represents the sigma transition.
1110   bool rewrite_both_;      // Rewrite both sides when both are sigma_label_?
1111   bool has_sigma_;         // Are there sigmas at the current state?
1112   Label sigma_match_;      // Current label that matches sigma transition.
1113   mutable Arc sigma_arc_;  // Arc to return when sigma match.
1114   Label match_label_;      // Label being matched.
1115   bool error_;             // Error encountered?
1116   StateId state_;          // Matcher state.
1117 };
1118
1119 template <class M>
1120 inline uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
1121   auto outprops = matcher_->Properties(inprops);
1122   if (error_) outprops |= kError;
1123   if (match_type_ == MATCH_NONE) {
1124     return outprops;
1125   } else if (rewrite_both_) {
1126     return outprops &
1127            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1128              kNonODeterministic | kILabelSorted | kNotILabelSorted |
1129              kOLabelSorted | kNotOLabelSorted | kString);
1130   } else if (match_type_ == MATCH_INPUT) {
1131     return outprops &
1132            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1133              kNonODeterministic | kILabelSorted | kNotILabelSorted | kString |
1134              kAcceptor);
1135   } else if (match_type_ == MATCH_OUTPUT) {
1136     return outprops &
1137            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1138              kNonODeterministic | kOLabelSorted | kNotOLabelSorted | kString |
1139              kAcceptor);
1140   } else {
1141     // Shouldn't ever get here.
1142     FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_;
1143     return 0;
1144   }
1145 }
1146
1147 // Flags for MultiEpsMatcher.
1148
1149 // Return multi-epsilon arcs for Find(kNoLabel).
1150 const uint32 kMultiEpsList = 0x00000001;
1151
1152 // Return a kNolabel loop for Find(multi_eps).
1153 const uint32 kMultiEpsLoop = 0x00000002;
1154
1155 // MultiEpsMatcher: allows treating multiple non-0 labels as
1156 // non-consuming labels in addition to 0 that is always
1157 // non-consuming. Precise behavior controlled by 'flags' argument. By
1158 // default, the underlying matcher is constructed by
1159 // MultiEpsMatcher. The user can instead pass in this object; in that
1160 // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
1161 // true.
1162 template <class M>
1163 class MultiEpsMatcher {
1164  public:
1165   using FST = typename M::FST;
1166   using Arc = typename FST::Arc;
1167   using Label = typename Arc::Label;
1168   using StateId = typename Arc::StateId;
1169   using Weight = typename Arc::Weight;
1170
1171   MultiEpsMatcher(const FST &fst, MatchType match_type,
1172                   uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1173                   M *matcher = nullptr, bool own_matcher = true)
1174       : matcher_(matcher ? matcher : new M(fst, match_type)),
1175         flags_(flags),
1176         own_matcher_(matcher ? own_matcher : true) {
1177     if (match_type == MATCH_INPUT) {
1178       loop_.ilabel = kNoLabel;
1179       loop_.olabel = 0;
1180     } else {
1181       loop_.ilabel = 0;
1182       loop_.olabel = kNoLabel;
1183     }
1184     loop_.weight = Weight::One();
1185     loop_.nextstate = kNoStateId;
1186   }
1187
1188   MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
1189       : matcher_(new M(*matcher.matcher_, safe)),
1190         flags_(matcher.flags_),
1191         own_matcher_(true),
1192         multi_eps_labels_(matcher.multi_eps_labels_),
1193         loop_(matcher.loop_) {
1194     loop_.nextstate = kNoStateId;
1195   }
1196
1197   ~MultiEpsMatcher() {
1198     if (own_matcher_) delete matcher_;
1199   }
1200
1201   MultiEpsMatcher<M> *Copy(bool safe = false) const {
1202     return new MultiEpsMatcher<M>(*this, safe);
1203   }
1204
1205   MatchType Type(bool test) const { return matcher_->Type(test); }
1206
1207   void SetState(StateId state) {
1208     matcher_->SetState(state);
1209     loop_.nextstate = state;
1210   }
1211
1212   bool Find(Label label);
1213
1214   bool Done() const { return done_; }
1215
1216   const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); }
1217
1218   void Next() {
1219     if (!current_loop_) {
1220       matcher_->Next();
1221       done_ = matcher_->Done();
1222       if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1223         ++multi_eps_iter_;
1224         while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1225                !matcher_->Find(*multi_eps_iter_)) {
1226           ++multi_eps_iter_;
1227         }
1228         if (multi_eps_iter_ != multi_eps_labels_.End()) {
1229           done_ = false;
1230         } else {
1231           done_ = !matcher_->Find(kNoLabel);
1232         }
1233       }
1234     } else {
1235       done_ = true;
1236     }
1237   }
1238
1239   const FST &GetFst() const { return matcher_->GetFst(); }
1240
1241   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1242
1243   const M *GetMatcher() const { return matcher_; }
1244
1245   Weight Final(StateId s) const { return matcher_->Final(s); }
1246
1247   uint32 Flags() const { return matcher_->Flags(); }
1248
1249   ssize_t Priority(StateId s) { return matcher_->Priority(s); }
1250
1251   void AddMultiEpsLabel(Label label) {
1252     if (label == 0) {
1253       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1254     } else {
1255       multi_eps_labels_.Insert(label);
1256     }
1257   }
1258
1259   void RemoveMultiEpsLabel(Label label) {
1260     if (label == 0) {
1261       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1262     } else {
1263       multi_eps_labels_.Erase(label);
1264     }
1265   }
1266
1267   void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); }
1268
1269  private:
1270   M *matcher_;
1271   uint32 flags_;
1272   bool own_matcher_;  // Does this class delete the matcher?
1273
1274   // Multi-eps label set.
1275   CompactSet<Label, kNoLabel> multi_eps_labels_;
1276   typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1277
1278   bool current_loop_;  // Current arc is the implicit loop?
1279   mutable Arc loop_;   // For non-consuming symbols.
1280   bool done_;          // Matching done?
1281
1282   MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete;
1283 };
1284
1285 template <class M>
1286 inline bool MultiEpsMatcher<M>::Find(Label label) {
1287   multi_eps_iter_ = multi_eps_labels_.End();
1288   current_loop_ = false;
1289   bool ret;
1290   if (label == 0) {
1291     ret = matcher_->Find(0);
1292   } else if (label == kNoLabel) {
1293     if (flags_ & kMultiEpsList) {
1294       // Returns all non-consuming arcs (including epsilon).
1295       multi_eps_iter_ = multi_eps_labels_.Begin();
1296       while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1297              !matcher_->Find(*multi_eps_iter_)) {
1298         ++multi_eps_iter_;
1299       }
1300       if (multi_eps_iter_ != multi_eps_labels_.End()) {
1301         ret = true;
1302       } else {
1303         ret = matcher_->Find(kNoLabel);
1304       }
1305     } else {
1306       // Returns all epsilon arcs.
1307       ret = matcher_->Find(kNoLabel);
1308     }
1309   } else if ((flags_ & kMultiEpsLoop) &&
1310              multi_eps_labels_.Find(label) != multi_eps_labels_.End()) {
1311     // Returns implicit loop.
1312     current_loop_ = true;
1313     ret = true;
1314   } else {
1315     ret = matcher_->Find(label);
1316   }
1317   done_ = !ret;
1318   return ret;
1319 }
1320
1321 // This class discards any implicit matches (e.g., the implicit epsilon
1322 // self-loops in the SortedMatcher). Matchers are most often used in
1323 // composition/intersection where the implicit matches are needed
1324 // e.g. for epsilon processing. However, if a matcher is simply being
1325 // used to look-up explicit label matches, this class saves the user
1326 // from having to check for and discard the unwanted implicit matches
1327 // themselves.
1328 template <class M>
1329 class ExplicitMatcher : public MatcherBase<typename M::Arc> {
1330  public:
1331   using FST = typename M::FST;
1332   using Arc = typename FST::Arc;
1333   using Label = typename Arc::Label;
1334   using StateId = typename Arc::StateId;
1335   using Weight = typename Arc::Weight;
1336
1337   ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr)
1338       : matcher_(matcher ? matcher : new M(fst, match_type)),
1339         match_type_(match_type),
1340         error_(false) {}
1341
1342   ExplicitMatcher(const ExplicitMatcher<M> &matcher, bool safe = false)
1343       : matcher_(new M(*matcher.matcher_, safe)),
1344         match_type_(matcher.match_type_),
1345         error_(matcher.error_) {}
1346
1347   ExplicitMatcher<M> *Copy(bool safe = false) const override {
1348     return new ExplicitMatcher<M>(*this, safe);
1349   }
1350
1351   MatchType Type(bool test) const override { return matcher_->Type(test); }
1352
1353   void SetState(StateId s) final { matcher_->SetState(s); }
1354
1355   bool Find(Label label) final {
1356     matcher_->Find(label);
1357     CheckArc();
1358     return !Done();
1359   }
1360
1361   bool Done() const final { return matcher_->Done(); }
1362
1363   const Arc &Value() const final { return matcher_->Value(); }
1364
1365   void Next() final {
1366     matcher_->Next();
1367     CheckArc();
1368   }
1369
1370   Weight Final(StateId s) const final { return matcher_->Final(s); }
1371
1372   ssize_t Priority(StateId s) final { return matcher_->Priority(s); }
1373
1374   const FST &GetFst() const final { return matcher_->GetFst(); }
1375
1376   uint64 Properties(uint64 inprops) const override {
1377     return matcher_->Properties(inprops);
1378   }
1379
1380   const M *GetMatcher() const { return matcher_.get(); }
1381
1382   uint32 Flags() const override { return matcher_->Flags(); }
1383
1384  private:
1385   // Checks current arc if available and explicit. If not available, stops. If
1386   // not explicit, checks next ones.
1387   void CheckArc() {
1388     for (; !matcher_->Done(); matcher_->Next()) {
1389       const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel
1390                                                     : matcher_->Value().olabel;
1391       if (label != kNoLabel) return;
1392     }
1393   }
1394
1395   std::unique_ptr<M> matcher_;
1396   MatchType match_type_;  // Type of match requested.
1397   bool error_;            // Error encountered?
1398 };
1399
1400 // Generic matcher, templated on the FST definition.
1401 //
1402 // Here is a typical use:
1403 //
1404 //   Matcher<StdFst> matcher(fst, MATCH_INPUT);
1405 //   matcher.SetState(state);
1406 //   if (matcher.Find(label))
1407 //     for (; !matcher.Done(); matcher.Next()) {
1408 //       auto &arc = matcher.Value();
1409 //       ...
1410 //     }
1411 template <class F>
1412 class Matcher {
1413  public:
1414   using FST = F;
1415   using Arc = typename F::Arc;
1416   using Label = typename Arc::Label;
1417   using StateId = typename Arc::StateId;
1418   using Weight = typename Arc::Weight;
1419
1420   Matcher(const FST &fst, MatchType match_type) {
1421     base_.reset(fst.InitMatcher(match_type));
1422     if (!base_) base_.reset(new SortedMatcher<FST>(fst, match_type));
1423   }
1424
1425   Matcher(const Matcher<FST> &matcher, bool safe = false) {
1426     base_.reset(matcher.base_->Copy(safe));
1427   }
1428
1429   // Takes ownership of the provided matcher.
1430   explicit Matcher(MatcherBase<Arc> *base_matcher) {
1431     base_.reset(base_matcher);
1432   }
1433
1434   Matcher<FST> *Copy(bool safe = false) const {
1435     return new Matcher<FST>(*this, safe);
1436   }
1437
1438   MatchType Type(bool test) const { return base_->Type(test); }
1439
1440   void SetState(StateId s) { base_->SetState(s); }
1441
1442   bool Find(Label label) { return base_->Find(label); }
1443
1444   bool Done() const { return base_->Done(); }
1445
1446   const Arc &Value() const { return base_->Value(); }
1447
1448   void Next() { base_->Next(); }
1449
1450   const FST &GetFst() const {
1451     return static_cast<const FST &>(base_->GetFst());
1452   }
1453
1454   uint64 Properties(uint64 props) const { return base_->Properties(props); }
1455
1456   Weight Final(StateId s) const { return base_->Final(s); }
1457
1458   uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1459
1460   ssize_t Priority(StateId s) { return base_->Priority(s); }
1461
1462  private:
1463   std::unique_ptr<MatcherBase<Arc>> base_;
1464 };
1465
1466 }  // namespace fst
1467
1468 #endif  // FST_LIB_MATCHER_H_