975036ac101589c1c52f330ecb6d9a8587107bc1
[platform/upstream/openfst.git] / src / include / fst / matcher.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to allow matching labels leaving FST states.
5
6 #ifndef FST_MATCHER_H_
7 #define FST_MATCHER_H_
8
9 #include <algorithm>
10 #include <unordered_map>
11 #include <utility>
12
13 #include <fst/log.h>
14
15 #include <fst/mutable-fst.h>  // for all internal FST accessors.
16
17
18 namespace fst {
19
20 // Matchers find and iterate through requested labels at FST states. In the
21 // simplest form, these are just some associative map or search keyed on labels.
22 // More generally, they may implement matching special labels that represent
23 // sets of labels such as sigma (all), rho (rest), or phi (fail). The Matcher
24 // interface is:
25 //
26 // template <class F>
27 // class Matcher {
28 //  public:
29 //   using FST = F;
30 //   using Arc = typename FST::Arc;
31 //   using Label = typename Arc::Label;
32 //   using StateId = typename Arc::StateId;
33 //   using Weight = typename Arc::Weight;
34 //
35 //   // Required constructors.
36 //
37 //   Matcher(const FST &fst, MatchType type);
38 //   Matcher(const Matcher &matcher, bool safe = false);
39 //
40 //   // Standard copy method.
41 //   Matcher<FST> *Copy(bool safe = false) const override;
42 //
43 //   // Returns the match type that can be provided (depending on compatibility
44 //   of the input FST). It is either the requested match type, MATCH_NONE, or
45 //   MATCH_UNKNOWN. If test is false, a costly testing is avoided, but
46 //   MATCH_UNKNOWN may be returned. If test is true, a definite answer is
47 //   returned, but may involve more costly computation (e.g., visiting the FST).
48 //   MatchType Type(bool test) const override;
49 //
50 //   // Specifies the current state.
51 //   void SetState(StateId s) final;
52 //
53 //   // Finds matches to a label at the current state, returning true if a match
54 //   // found. kNoLabel matches any non-consuming transitions, e.g., epsilon
55 //   // transitions, which do not require a matching symbol.
56 //   bool Find(Label label) final;
57 //
58 //   // Iterator methods. Note that initially and after SetState() these have
59 //   undefined behavior until Find() is called.
60 //
61 //   bool Done() const final;
62 //
63 //   const Arc &Value() const final;
64 //
65 //   void Next() final;
66 //
67 //   // Returns final weight of a state.
68 //   Weight Final(StateId) const final;
69 //
70 //   // Indicates preference for being the side used for matching in
71 //   // composition. If the value is kRequirePriority, then it is
72 //   // mandatory that it be used. Calling this method without passing the
73 //   // current state of the matcher invalidates the state of the matcher.
74 //   ssize_t Priority(StateId s) final;
75 //
76 //   // This specifies the known FST properties as viewed from this matcher. It
77 //   // takes as argument the input FST's known properties.
78 //   uint64 Properties(uint64 props) const override;
79 //
80 //   // Returns matcher flags.
81 //   uint32 Flags() const override;
82 //
83 //   // Returns matcher FST.
84 //   const FST &GetFst() const override;
85 // };
86
87 // Basic matcher flags.
88
89 // Matcher needs to be used as the matching side in composition for
90 // at least one state (has kRequirePriority).
91 constexpr uint32 kRequireMatch = 0x00000001;
92
93 // Flags used for basic matchers (see also lookahead.h).
94 constexpr uint32 kMatcherFlags = kRequireMatch;
95
96 // Matcher priority that is mandatory.
97 constexpr ssize_t kRequirePriority = -1;
98
99 // Matcher interface, templated on the Arc definition; used for matcher
100 // specializations that are returned by the InitMatcher FST method.
101 template <class A>
102 class MatcherBase {
103  public:
104   using Arc = A;
105   using Label = typename Arc::Label;
106   using StateId = typename Arc::StateId;
107   using Weight = typename Arc::Weight;
108
109   virtual ~MatcherBase() {}
110
111   // Virtual interface.
112
113   virtual MatcherBase<Arc> *Copy(bool safe = false) const = 0;
114   virtual MatchType Type(bool) const = 0;
115   virtual void SetState(StateId) = 0;
116   virtual bool Find(Label) = 0;
117   virtual bool Done() const = 0;
118   virtual const Arc &Value() const = 0;
119   virtual void Next() = 0;
120   virtual const Fst<Arc> &GetFst() const = 0;
121   virtual uint64 Properties(uint64) const = 0;
122
123   // Trivial implementations that can be used by derived classes. Full
124   // devirtualization is expected for any derived class marked final.
125   virtual uint32 Flags() const { return 0; }
126
127   virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); }
128
129   virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); }
130 };
131
132 // A matcher that expects sorted labels on the side to be matched.
133 // If match_type == MATCH_INPUT, epsilons match the implicit self-loop
134 // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
135 // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
136 // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
137 template <class F>
138 class SortedMatcher : public MatcherBase<typename F::Arc> {
139  public:
140   using FST = F;
141   using Arc = typename FST::Arc;
142   using Label = typename Arc::Label;
143   using StateId = typename Arc::StateId;
144   using Weight = typename Arc::Weight;
145
146   using MatcherBase<Arc>::Flags;
147   using MatcherBase<Arc>::Properties;
148
149   // Labels >= binary_label will be searched for by binary search;
150   // o.w. linear search is used.
151   SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1)
152       : fst_(fst.Copy()),
153         state_(kNoStateId),
154         aiter_(nullptr),
155         match_type_(match_type),
156         binary_label_(binary_label),
157         match_label_(kNoLabel),
158         narcs_(0),
159         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
160         error_(false),
161         aiter_pool_(1) {
162     switch (match_type_) {
163       case MATCH_INPUT:
164       case MATCH_NONE:
165         break;
166       case MATCH_OUTPUT:
167         std::swap(loop_.ilabel, loop_.olabel);
168         break;
169       default:
170         FSTERROR() << "SortedMatcher: Bad match type";
171         match_type_ = MATCH_NONE;
172         error_ = true;
173     }
174   }
175
176   SortedMatcher(const SortedMatcher<FST> &matcher, bool safe = false)
177       : fst_(matcher.fst_->Copy(safe)),
178         state_(kNoStateId),
179         aiter_(nullptr),
180         match_type_(matcher.match_type_),
181         binary_label_(matcher.binary_label_),
182         match_label_(kNoLabel),
183         narcs_(0),
184         loop_(matcher.loop_),
185         error_(matcher.error_),
186         aiter_pool_(1) {}
187
188   ~SortedMatcher() override { Destroy(aiter_, &aiter_pool_); }
189
190   SortedMatcher<FST> *Copy(bool safe = false) const override {
191     return new SortedMatcher<FST>(*this, safe);
192   }
193
194   MatchType Type(bool test) const override {
195     if (match_type_ == MATCH_NONE) return match_type_;
196     const auto true_prop =
197         match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
198     const auto false_prop =
199         match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
200     const auto props = fst_->Properties(true_prop | false_prop, test);
201     if (props & true_prop) {
202       return match_type_;
203     } else if (props & false_prop) {
204       return MATCH_NONE;
205     } else {
206       return MATCH_UNKNOWN;
207     }
208   }
209
210   void SetState(StateId s) final {
211     if (state_ == s) return;
212     state_ = s;
213     if (match_type_ == MATCH_NONE) {
214       FSTERROR() << "SortedMatcher: Bad match type";
215       error_ = true;
216     }
217     Destroy(aiter_, &aiter_pool_);
218     aiter_ = new (&aiter_pool_) ArcIterator<FST>(*fst_, s);
219     aiter_->SetFlags(kArcNoCache, kArcNoCache);
220     narcs_ = internal::NumArcs(*fst_, s);
221     loop_.nextstate = s;
222   }
223
224   bool Find(Label match_label) final {
225     exact_match_ = true;
226     if (error_) {
227       current_loop_ = false;
228       match_label_ = kNoLabel;
229       return false;
230     }
231     current_loop_ = match_label == 0;
232     match_label_ = match_label == kNoLabel ? 0 : match_label;
233     if (Search()) {
234       return true;
235     } else {
236       return current_loop_;
237     }
238   }
239
240   // Positions matcher to the first position where inserting match_label would
241   // maintain the sort order.
242   void LowerBound(Label label) {
243     exact_match_ = false;
244     current_loop_ = false;
245     if (error_) {
246       match_label_ = kNoLabel;
247       return;
248     }
249     match_label_ = label;
250     Search();
251   }
252
253   // After Find(), returns false if no more exact matches.
254   // After LowerBound(), returns false if no more arcs.
255   bool Done() const final {
256     if (current_loop_) return false;
257     if (aiter_->Done()) return true;
258     if (!exact_match_) return false;
259     aiter_->SetFlags(
260         match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
261         kArcValueFlags);
262     const auto label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel
263                                                   : aiter_->Value().olabel;
264     return label != match_label_;
265   }
266
267   const Arc &Value() const final {
268     if (current_loop_) return loop_;
269     aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
270     return aiter_->Value();
271   }
272
273   void Next() final {
274     if (current_loop_) {
275       current_loop_ = false;
276     } else {
277       aiter_->Next();
278     }
279   }
280
281   Weight Final(StateId s) const final {
282     return MatcherBase<Arc>::Final(s);
283   }
284
285   ssize_t Priority(StateId s) final {
286     return MatcherBase<Arc>::Priority(s);
287   }
288
289   const FST &GetFst() const override { return *fst_; }
290
291   uint64 Properties(uint64 inprops) const override {
292     return inprops | (error_ ? kError : 0);
293   }
294
295   size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
296
297  private:
298   Label GetLabel() const {
299     const auto &arc = aiter_->Value();
300     return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
301   }
302
303   bool BinarySearch();
304   bool LinearSearch();
305   bool Search();
306
307   std::unique_ptr<const FST> fst_;
308   StateId state_;            // Matcher state.
309   ArcIterator<FST> *aiter_;  // Iterator for current state.
310   MatchType match_type_;     // Type of match to perform.
311   Label binary_label_;       // Least label for binary search.
312   Label match_label_;        // Current label to be matched.
313   size_t narcs_;             // Current state arc count.
314   Arc loop_;                 // For non-consuming symbols.
315   bool current_loop_;        // Current arc is the implicit loop.
316   bool exact_match_;         // Exact match or lower bound?
317   bool error_;               // Error encountered?
318   MemoryPool<ArcIterator<FST>> aiter_pool_;  // Pool of arc iterators.
319 };
320
321 // Returns true iff match to match_label_. The arc iterator is positioned at the
322 // lower bound, that is, the first element greater than or equal to
323 // match_label_, or the end if all elements are less than match_label_.
324 template <class FST>
325 inline bool SortedMatcher<FST>::BinarySearch() {
326   size_t low = 0;
327   size_t high = narcs_;
328   while (low < high) {
329     const size_t mid = low + (high - low) / 2;
330     aiter_->Seek(mid);
331     if (GetLabel() < match_label_) {
332       low = mid + 1;
333     } else {
334       high = mid;
335     }
336   }
337
338   aiter_->Seek(low);
339   return low < narcs_ && GetLabel() == match_label_;
340 }
341
342 // Returns true iff match to match_label_, positioning arc iterator at lower
343 // bound.
344 template <class FST>
345 inline bool SortedMatcher<FST>::LinearSearch() {
346   for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
347     const auto label = GetLabel();
348     if (label == match_label_) return true;
349     if (label > match_label_) break;
350   }
351   return false;
352 }
353
354 // Returns true iff match to match_label_, positioning arc iterator at lower
355 // bound.
356 template <class FST>
357 inline bool SortedMatcher<FST>::Search() {
358   aiter_->SetFlags(match_type_ == MATCH_INPUT ?
359                    kArcILabelValue : kArcOLabelValue,
360                    kArcValueFlags);
361   if (match_label_ >= binary_label_) {
362     return BinarySearch();
363   } else {
364     return LinearSearch();
365   }
366 }
367
368 // A matcher that stores labels in a per-state hash table populated upon the
369 // first visit to that state. Sorting is not required. Treatment of
370 // epsilons are the same as with SortedMatcher.
371 template <class F>
372 class HashMatcher : public MatcherBase<typename F::Arc> {
373  public:
374   using FST = F;
375   using Arc = typename FST::Arc;
376   using Label = typename Arc::Label;
377   using StateId = typename Arc::StateId;
378   using Weight = typename Arc::Weight;
379
380   using MatcherBase<Arc>::Flags;
381   using MatcherBase<Arc>::Final;
382   using MatcherBase<Arc>::Priority;
383
384   HashMatcher(const FST &fst, MatchType match_type)
385       : fst_(fst.Copy()),
386         state_(kNoStateId),
387         match_type_(match_type),
388         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
389     switch (match_type_) {
390       case MATCH_INPUT:
391       case MATCH_NONE:
392         break;
393       case MATCH_OUTPUT:
394         std::swap(loop_.ilabel, loop_.olabel);
395         break;
396       default:
397         FSTERROR() << "SortedMatcher: Bad match type";
398         match_type_ = MATCH_NONE;
399         error_ = true;
400     }
401   }
402
403   HashMatcher(const HashMatcher<FST> &matcher, bool safe = false)
404       : fst_(matcher.fst_->Copy(safe)),
405         state_(kNoStateId),
406         match_type_(matcher.match_type_),
407         loop_(matcher.loop_),
408         error_(matcher.error_) {}
409
410   HashMatcher<FST> *Copy(bool safe = false) const override {
411     return new HashMatcher<FST>(*this, safe);
412   }
413
414   // The argument is ignored as there are no relevant properties to test.
415   MatchType Type(bool test) const override { return match_type_; }
416
417   void SetState(StateId s) final;
418
419   bool Find(Label label) final {
420     current_loop_ = label == 0;
421     if (label == 0) {
422       Search(label);
423       return true;
424     }
425     if (label == kNoLabel) label = 0;
426     return Search(label);
427   }
428
429   bool Done() const final {
430     if (current_loop_) return false;
431     return label_it_ == label_end_;
432   }
433
434   const Arc &Value() const final {
435     if (current_loop_) return loop_;
436     aiter_->Seek(label_it_->second);
437     return aiter_->Value();
438   }
439
440   void Next() final {
441     if (current_loop_) {
442       current_loop_ = false;
443     } else {
444       ++label_it_;
445     }
446   }
447
448   const FST &GetFst() const override { return *fst_; }
449
450   uint64 Properties(uint64 inprops) const override {
451     return inprops | (error_ ? kError : 0);
452   }
453
454  private:
455   bool Search(Label match_label);
456
457   using LabelTable = std::unordered_multimap<Label, size_t>;
458   using StateTable = std::unordered_map<StateId, LabelTable>;
459
460   std::unique_ptr<const FST> fst_;
461   StateId state_;  // Matcher state.
462   MatchType match_type_;
463   Arc loop_;           // The implicit loop itself.
464   bool current_loop_;  // Is the current arc is the implicit loop?
465   bool error_;         // Error encountered?
466   std::unique_ptr<ArcIterator<FST>> aiter_;
467   StateTable state_table_;   // Table from states to label table.
468   LabelTable *label_table_;  // Pointer to current state's label table.
469   typename LabelTable::iterator label_it_;   // Position for label.
470   typename LabelTable::iterator label_end_;  // Position for last label + 1.
471 };
472
473 template <class FST>
474 void HashMatcher<FST>::SetState(typename FST::Arc::StateId s) {
475   if (state_ == s) return;
476   // Resets everything for the state.
477   state_ = s;
478   loop_.nextstate = state_;
479   aiter_.reset(new ArcIterator<FST>(*fst_, state_));
480   if (match_type_ == MATCH_NONE) {
481     FSTERROR() << "HashMatcher: Bad match type";
482     error_ = false;
483   }
484   // Attempts to insert a new label table; if it already exists,
485   // no additional work is done and we simply return.
486   auto it_and_success = state_table_.emplace(state_, LabelTable());
487   if (!it_and_success.second) return;
488   // Otherwise, populate this new table.
489   // Sets instance's pointer to the label table for this state.
490   label_table_ = &(it_and_success.first->second);
491   // Populates the label table.
492   label_table_->reserve(internal::NumArcs(*fst_, state_));
493   const auto aiter_flags =
494       (match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue) |
495       kArcNoCache;
496   aiter_->SetFlags(aiter_flags, kArcFlags);
497   for (; !aiter_->Done(); aiter_->Next()) {
498     const auto label = (match_type_ == MATCH_INPUT) ? aiter_->Value().ilabel
499                                                     : aiter_->Value().olabel;
500     label_table_->emplace(label, aiter_->Position());
501   }
502   aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
503 }
504
505 template <class FST>
506 inline bool HashMatcher<FST>::Search(typename FST::Arc::Label match_label) {
507   auto range = label_table_->equal_range(match_label);
508   if (range.first == range.second) return false;
509   label_it_ = range.first;
510   label_end_ = range.second;
511   aiter_->Seek(label_it_->second);
512   return true;
513 }
514
515 // Specifies whether we rewrite both the input and output sides during matching.
516 enum MatcherRewriteMode {
517   MATCHER_REWRITE_AUTO = 0,  // Rewrites both sides iff acceptor.
518   MATCHER_REWRITE_ALWAYS,
519   MATCHER_REWRITE_NEVER
520 };
521
522 // For any requested label that doesn't match at a state, this matcher
523 // considers the *unique* transition that matches the label 'phi_label'
524 // (phi = 'fail'), and recursively looks for a match at its
525 // destination.  When 'phi_loop' is true, if no match is found but a
526 // phi self-loop is found, then the phi transition found is returned
527 // with the phi_label rewritten as the requested label (both sides if
528 // an acceptor, or if 'rewrite_both' is true and both input and output
529 // labels of the found transition are 'phi_label').  If 'phi_label' is
530 // kNoLabel, this special matching is not done.  PhiMatcher is
531 // templated itself on a matcher, which is used to perform the
532 // underlying matching.  By default, the underlying matcher is
533 // constructed by PhiMatcher. The user can instead pass in this
534 // object; in that case, PhiMatcher takes its ownership.
535 // Phi non-determinism not supported. No non-consuming symbols other
536 // than epsilon supported with the underlying template argument matcher.
537 template <class M>
538 class PhiMatcher : public MatcherBase<typename M::Arc> {
539  public:
540   using FST = typename M::FST;
541   using Arc = typename FST::Arc;
542   using Label = typename Arc::Label;
543   using StateId = typename Arc::StateId;
544   using Weight = typename Arc::Weight;
545
546   PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel,
547              bool phi_loop = true,
548              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
549              M *matcher = nullptr)
550       : matcher_(matcher ? matcher : new M(fst, match_type)),
551         match_type_(match_type),
552         phi_label_(phi_label),
553         state_(kNoStateId),
554         phi_loop_(phi_loop),
555         error_(false) {
556     if (match_type == MATCH_BOTH) {
557       FSTERROR() << "PhiMatcher: Bad match type";
558       match_type_ = MATCH_NONE;
559       error_ = true;
560     }
561     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
562       rewrite_both_ = fst.Properties(kAcceptor, true);
563     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
564       rewrite_both_ = true;
565     } else {
566       rewrite_both_ = false;
567     }
568   }
569
570   PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
571       : matcher_(new M(*matcher.matcher_, safe)),
572         match_type_(matcher.match_type_),
573         phi_label_(matcher.phi_label_),
574         rewrite_both_(matcher.rewrite_both_),
575         state_(kNoStateId),
576         phi_loop_(matcher.phi_loop_),
577         error_(matcher.error_) {}
578
579   PhiMatcher<M> *Copy(bool safe = false) const override {
580     return new PhiMatcher<M>(*this, safe);
581   }
582
583   MatchType Type(bool test) const override { return matcher_->Type(test); }
584
585   void SetState(StateId s) final {
586     if (state_ == s) return;
587     matcher_->SetState(s);
588     state_ = s;
589     has_phi_ = phi_label_ != kNoLabel;
590   }
591
592   bool Find(Label match_label) final;
593
594   bool Done() const final { return matcher_->Done(); }
595
596   const Arc &Value() const final {
597     if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
598       return matcher_->Value();
599     } else if (phi_match_ == 0) {  // Virtual epsilon loop.
600       phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
601       if (match_type_ == MATCH_OUTPUT) {
602         std::swap(phi_arc_.ilabel, phi_arc_.olabel);
603       }
604       return phi_arc_;
605     } else {
606       phi_arc_ = matcher_->Value();
607       phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
608       if (phi_match_ != kNoLabel) {  // Phi loop match.
609         if (rewrite_both_) {
610           if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_;
611           if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_;
612         } else if (match_type_ == MATCH_INPUT) {
613           phi_arc_.ilabel = phi_match_;
614         } else {
615           phi_arc_.olabel = phi_match_;
616         }
617       }
618       return phi_arc_;
619     }
620   }
621
622   void Next() final { matcher_->Next(); }
623
624   Weight Final(StateId s) const final {
625     auto weight = matcher_->Final(s);
626     if (phi_label_ == kNoLabel || weight != Weight::Zero()) {
627       return weight;
628     }
629     weight = Weight::One();
630     matcher_->SetState(s);
631     while (matcher_->Final(s) == Weight::Zero()) {
632       if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) break;
633       weight = Times(weight, matcher_->Value().weight);
634       if (s == matcher_->Value().nextstate) {
635         return Weight::Zero();  // Does not follow phi self-loops.
636       }
637       s = matcher_->Value().nextstate;
638       matcher_->SetState(s);
639     }
640     weight = Times(weight, matcher_->Final(s));
641     return weight;
642   }
643
644   ssize_t Priority(StateId s) final {
645     if (phi_label_ != kNoLabel) {
646       matcher_->SetState(s);
647       const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_);
648       return has_phi ? kRequirePriority : matcher_->Priority(s);
649     } else {
650       return matcher_->Priority(s);
651     }
652   }
653
654   const FST &GetFst() const override { return matcher_->GetFst(); }
655
656   uint64 Properties(uint64 props) const override;
657
658   uint32 Flags() const override {
659     if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) {
660       return matcher_->Flags();
661     }
662     return matcher_->Flags() | kRequireMatch;
663   }
664
665   Label PhiLabel() const { return phi_label_; }
666
667  private:
668   mutable std::unique_ptr<M> matcher_;
669   MatchType match_type_;  // Type of match requested.
670   Label phi_label_;       // Label that represents the phi transition.
671   bool rewrite_both_;     // Rewrite both sides when both are phi_label_?
672   bool has_phi_;          // Are there possibly phis at the current state?
673   Label phi_match_;       // Current label that matches phi loop.
674   mutable Arc phi_arc_;   // Arc to return.
675   StateId state_;         // Matcher state.
676   Weight phi_weight_;     // Product of the weights of phi transitions taken.
677   bool phi_loop_;         // When true, phi self-loop are allowed and treated
678                           // as rho (required for Aho-Corasick).
679   bool error_;            // Error encountered?
680
681   PhiMatcher &operator=(const PhiMatcher &) = delete;
682 };
683
684 template <class M>
685 inline bool PhiMatcher<M>::Find(Label label) {
686   if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
687     FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
688     error_ = true;
689     return false;
690   }
691   matcher_->SetState(state_);
692   phi_match_ = kNoLabel;
693   phi_weight_ = Weight::One();
694   // If phi_label_ == 0, there are no more true epsilon arcs.
695   if (phi_label_ == 0) {
696     if (label == kNoLabel) {
697       return false;
698     }
699     if (label == 0) {  // but a virtual epsilon loop needs to be returned.
700       if (!matcher_->Find(kNoLabel)) {
701         return matcher_->Find(0);
702       } else {
703         phi_match_ = 0;
704         return true;
705       }
706     }
707   }
708   if (!has_phi_ || label == 0 || label == kNoLabel) {
709     return matcher_->Find(label);
710   }
711   auto s = state_;
712   while (!matcher_->Find(label)) {
713     // Look for phi transition (if phi_label_ == 0, we need to look
714     // for -1 to avoid getting the virtual self-loop)
715     if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false;
716     if (phi_loop_ && matcher_->Value().nextstate == s) {
717       phi_match_ = label;
718       return true;
719     }
720     phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
721     s = matcher_->Value().nextstate;
722     matcher_->Next();
723     if (!matcher_->Done()) {
724       FSTERROR() << "PhiMatcher: Phi non-determinism not supported";
725       error_ = true;
726     }
727     matcher_->SetState(s);
728   }
729   return true;
730 }
731
732 template <class M>
733 inline uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
734   auto outprops = matcher_->Properties(inprops);
735   if (error_) outprops |= kError;
736   if (match_type_ == MATCH_NONE) {
737     return outprops;
738   } else if (match_type_ == MATCH_INPUT) {
739     if (phi_label_ == 0) {
740       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
741       outprops |= kNoEpsilons | kNoIEpsilons;
742     }
743     if (rewrite_both_) {
744       return outprops &
745              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
746                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
747     } else {
748       return outprops &
749              ~(kODeterministic | kAcceptor | kString | kILabelSorted |
750                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
751     }
752   } else if (match_type_ == MATCH_OUTPUT) {
753     if (phi_label_ == 0) {
754       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
755       outprops |= kNoEpsilons | kNoOEpsilons;
756     }
757     if (rewrite_both_) {
758       return outprops &
759              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
760                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
761     } else {
762       return outprops &
763              ~(kIDeterministic | kAcceptor | kString | kILabelSorted |
764                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
765     }
766   } else {
767     // Shouldn't ever get here.
768     FSTERROR() << "PhiMatcher: Bad match type: " << match_type_;
769     return 0;
770   }
771 }
772
773 // For any requested label that doesn't match at a state, this matcher
774 // considers all transitions that match the label 'rho_label' (rho =
775 // 'rest').  Each such rho transition found is returned with the
776 // rho_label rewritten as the requested label (both sides if an
777 // acceptor, or if 'rewrite_both' is true and both input and output
778 // labels of the found transition are 'rho_label').  If 'rho_label' is
779 // kNoLabel, this special matching is not done.  RhoMatcher is
780 // templated itself on a matcher, which is used to perform the
781 // underlying matching.  By default, the underlying matcher is
782 // constructed by RhoMatcher.  The user can instead pass in this
783 // object; in that case, RhoMatcher takes its ownership.
784 // No non-consuming symbols other than epsilon supported with
785 // the underlying template argument matcher.
786 template <class M>
787 class RhoMatcher : public MatcherBase<typename M::Arc> {
788  public:
789   using FST = typename M::FST;
790   using Arc = typename FST::Arc;
791   using Label = typename Arc::Label;
792   using StateId = typename Arc::StateId;
793   using Weight = typename Arc::Weight;
794
795   RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel,
796              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
797              M *matcher = nullptr)
798       : matcher_(matcher ? matcher : new M(fst, match_type)),
799         match_type_(match_type),
800         rho_label_(rho_label),
801         error_(false),
802         state_(kNoStateId) {
803     if (match_type == MATCH_BOTH) {
804       FSTERROR() << "RhoMatcher: Bad match type";
805       match_type_ = MATCH_NONE;
806       error_ = true;
807     }
808     if (rho_label == 0) {
809       FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
810       rho_label_ = kNoLabel;
811       error_ = true;
812     }
813     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
814       rewrite_both_ = fst.Properties(kAcceptor, true);
815     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
816       rewrite_both_ = true;
817     } else {
818       rewrite_both_ = false;
819     }
820   }
821
822   RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
823       : matcher_(new M(*matcher.matcher_, safe)),
824         match_type_(matcher.match_type_),
825         rho_label_(matcher.rho_label_),
826         rewrite_both_(matcher.rewrite_both_),
827         error_(matcher.error_),
828         state_(kNoStateId) {}
829
830   RhoMatcher<M> *Copy(bool safe = false) const override {
831     return new RhoMatcher<M>(*this, safe);
832   }
833
834   MatchType Type(bool test) const override { return matcher_->Type(test); }
835
836   void SetState(StateId s) final {
837     if (state_ == s) return;
838     state_ = s;
839     matcher_->SetState(s);
840     has_rho_ = rho_label_ != kNoLabel;
841   }
842
843   bool Find(Label label) final {
844     if (label == rho_label_ && rho_label_ != kNoLabel) {
845       FSTERROR() << "RhoMatcher::Find: bad label (rho)";
846       error_ = true;
847       return false;
848     }
849     if (matcher_->Find(label)) {
850       rho_match_ = kNoLabel;
851       return true;
852     } else if (has_rho_ && label != 0 && label != kNoLabel &&
853                (has_rho_ = matcher_->Find(rho_label_))) {
854       rho_match_ = label;
855       return true;
856     } else {
857       return false;
858     }
859   }
860
861   bool Done() const final { return matcher_->Done(); }
862
863   const Arc &Value() const final {
864     if (rho_match_ == kNoLabel) {
865       return matcher_->Value();
866     } else {
867       rho_arc_ = matcher_->Value();
868       if (rewrite_both_) {
869         if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_;
870         if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_;
871       } else if (match_type_ == MATCH_INPUT) {
872         rho_arc_.ilabel = rho_match_;
873       } else {
874         rho_arc_.olabel = rho_match_;
875       }
876       return rho_arc_;
877     }
878   }
879
880   void Next() final { matcher_->Next(); }
881
882   Weight Final(StateId s) const final { return matcher_->Final(s); }
883
884   ssize_t Priority(StateId s) final {
885     state_ = s;
886     matcher_->SetState(s);
887     has_rho_ = matcher_->Find(rho_label_);
888     if (has_rho_) {
889       return kRequirePriority;
890     } else {
891       return matcher_->Priority(s);
892     }
893   }
894
895   const FST &GetFst() const override { return matcher_->GetFst(); }
896
897   uint64 Properties(uint64 props) const override;
898
899   uint32 Flags() const override {
900     if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) {
901       return matcher_->Flags();
902     }
903     return matcher_->Flags() | kRequireMatch;
904   }
905
906   Label RhoLabel() const { return rho_label_; }
907
908  private:
909   std::unique_ptr<M> matcher_;
910   MatchType match_type_;  // Type of match requested.
911   Label rho_label_;       // Label that represents the rho transition
912   bool rewrite_both_;     // Rewrite both sides when both are rho_label_?
913   bool has_rho_;          // Are there possibly rhos at the current state?
914   Label rho_match_;       // Current label that matches rho transition.
915   mutable Arc rho_arc_;   // Arc to return when rho match.
916   bool error_;            // Error encountered?
917   StateId state_;         // Matcher state.
918 };
919
920 template <class M>
921 inline uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
922   auto outprops = matcher_->Properties(inprops);
923   if (error_) outprops |= kError;
924   if (match_type_ == MATCH_NONE) {
925     return outprops;
926   } else if (match_type_ == MATCH_INPUT) {
927     if (rewrite_both_) {
928       return outprops &
929              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
930                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
931     } else {
932       return outprops &
933              ~(kODeterministic | kAcceptor | kString | kILabelSorted |
934                kNotILabelSorted);
935     }
936   } else if (match_type_ == MATCH_OUTPUT) {
937     if (rewrite_both_) {
938       return outprops &
939              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
940                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
941     } else {
942       return outprops &
943              ~(kIDeterministic | kAcceptor | kString | kOLabelSorted |
944                kNotOLabelSorted);
945     }
946   } else {
947     // Shouldn't ever get here.
948     FSTERROR() << "RhoMatcher: Bad match type: " << match_type_;
949     return 0;
950   }
951 }
952
953 // For any requested label, this matcher considers all transitions
954 // that match the label 'sigma_label' (sigma = "any"), and this in
955 // additions to transitions with the requested label.  Each such sigma
956 // transition found is returned with the sigma_label rewritten as the
957 // requested label (both sides if an acceptor, or if 'rewrite_both' is
958 // true and both input and output labels of the found transition are
959 // 'sigma_label').  If 'sigma_label' is kNoLabel, this special
960 // matching is not done.  SigmaMatcher is templated itself on a
961 // matcher, which is used to perform the underlying matching.  By
962 // default, the underlying matcher is constructed by SigmaMatcher.
963 // The user can instead pass in this object; in that case,
964 // SigmaMatcher takes its ownership.  No non-consuming symbols other
965 // than epsilon supported with the underlying template argument matcher.
966 template <class M>
967 class SigmaMatcher : public MatcherBase<typename M::Arc> {
968  public:
969   using FST = typename M::FST;
970   using Arc = typename FST::Arc;
971   using Label = typename Arc::Label;
972   using StateId = typename Arc::StateId;
973   using Weight = typename Arc::Weight;
974
975   SigmaMatcher(const FST &fst, MatchType match_type,
976                Label sigma_label = kNoLabel,
977                MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
978                M *matcher = nullptr)
979       : matcher_(matcher ? matcher : new M(fst, match_type)),
980         match_type_(match_type),
981         sigma_label_(sigma_label),
982         error_(false),
983         state_(kNoStateId) {
984     if (match_type == MATCH_BOTH) {
985       FSTERROR() << "SigmaMatcher: Bad match type";
986       match_type_ = MATCH_NONE;
987       error_ = true;
988     }
989     if (sigma_label == 0) {
990       FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
991       sigma_label_ = kNoLabel;
992       error_ = true;
993     }
994     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
995       rewrite_both_ = fst.Properties(kAcceptor, true);
996     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
997       rewrite_both_ = true;
998     } else {
999       rewrite_both_ = false;
1000     }
1001   }
1002
1003   SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
1004       : matcher_(new M(*matcher.matcher_, safe)),
1005         match_type_(matcher.match_type_),
1006         sigma_label_(matcher.sigma_label_),
1007         rewrite_both_(matcher.rewrite_both_),
1008         error_(matcher.error_),
1009         state_(kNoStateId) {}
1010
1011   SigmaMatcher<M> *Copy(bool safe = false) const override {
1012     return new SigmaMatcher<M>(*this, safe);
1013   }
1014
1015   MatchType Type(bool test) const override { return matcher_->Type(test); }
1016
1017   void SetState(StateId s) final {
1018     if (state_ == s) return;
1019     state_ = s;
1020     matcher_->SetState(s);
1021     has_sigma_ =
1022         (sigma_label_ != kNoLabel) ? matcher_->Find(sigma_label_) : false;
1023   }
1024
1025   bool Find(Label match_label) final {
1026     match_label_ = match_label;
1027     if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
1028       FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
1029       error_ = true;
1030       return false;
1031     }
1032     if (matcher_->Find(match_label)) {
1033       sigma_match_ = kNoLabel;
1034       return true;
1035     } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
1036                matcher_->Find(sigma_label_)) {
1037       sigma_match_ = match_label;
1038       return true;
1039     } else {
1040       return false;
1041     }
1042   }
1043
1044   bool Done() const final { return matcher_->Done(); }
1045
1046   const Arc &Value() const final {
1047     if (sigma_match_ == kNoLabel) {
1048       return matcher_->Value();
1049     } else {
1050       sigma_arc_ = matcher_->Value();
1051       if (rewrite_both_) {
1052         if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_;
1053         if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_;
1054       } else if (match_type_ == MATCH_INPUT) {
1055         sigma_arc_.ilabel = sigma_match_;
1056       } else {
1057         sigma_arc_.olabel = sigma_match_;
1058       }
1059       return sigma_arc_;
1060     }
1061   }
1062
1063   void Next() final {
1064     matcher_->Next();
1065     if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
1066         (match_label_ > 0)) {
1067       matcher_->Find(sigma_label_);
1068       sigma_match_ = match_label_;
1069     }
1070   }
1071
1072   Weight Final(StateId s) const final { return matcher_->Final(s); }
1073
1074   ssize_t Priority(StateId s) final {
1075     if (sigma_label_ != kNoLabel) {
1076       SetState(s);
1077       return has_sigma_ ? kRequirePriority : matcher_->Priority(s);
1078     } else {
1079       return matcher_->Priority(s);
1080     }
1081   }
1082
1083   const FST &GetFst() const override { return matcher_->GetFst(); }
1084
1085   uint64 Properties(uint64 props) const override;
1086
1087   uint32 Flags() const override {
1088     if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) {
1089       return matcher_->Flags();
1090     }
1091     return matcher_->Flags() | kRequireMatch;
1092   }
1093
1094   Label SigmaLabel() const { return sigma_label_; }
1095
1096  private:
1097   std::unique_ptr<M> matcher_;
1098   MatchType match_type_;   // Type of match requested.
1099   Label sigma_label_;      // Label that represents the sigma transition.
1100   bool rewrite_both_;      // Rewrite both sides when both are sigma_label_?
1101   bool has_sigma_;         // Are there sigmas at the current state?
1102   Label sigma_match_;      // Current label that matches sigma transition.
1103   mutable Arc sigma_arc_;  // Arc to return when sigma match.
1104   Label match_label_;      // Label being matched.
1105   bool error_;             // Error encountered?
1106   StateId state_;          // Matcher state.
1107 };
1108
1109 template <class M>
1110 inline uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
1111   auto outprops = matcher_->Properties(inprops);
1112   if (error_) outprops |= kError;
1113   if (match_type_ == MATCH_NONE) {
1114     return outprops;
1115   } else if (rewrite_both_) {
1116     return outprops &
1117            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1118              kNonODeterministic | kILabelSorted | kNotILabelSorted |
1119              kOLabelSorted | kNotOLabelSorted | kString);
1120   } else if (match_type_ == MATCH_INPUT) {
1121     return outprops &
1122            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1123              kNonODeterministic | kILabelSorted | kNotILabelSorted | kString |
1124              kAcceptor);
1125   } else if (match_type_ == MATCH_OUTPUT) {
1126     return outprops &
1127            ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1128              kNonODeterministic | kOLabelSorted | kNotOLabelSorted | kString |
1129              kAcceptor);
1130   } else {
1131     // Shouldn't ever get here.
1132     FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_;
1133     return 0;
1134   }
1135 }
1136
1137 // Flags for MultiEpsMatcher.
1138
1139 // Return multi-epsilon arcs for Find(kNoLabel).
1140 const uint32 kMultiEpsList = 0x00000001;
1141
1142 // Return a kNolabel loop for Find(multi_eps).
1143 const uint32 kMultiEpsLoop = 0x00000002;
1144
1145 // MultiEpsMatcher: allows treating multiple non-0 labels as
1146 // non-consuming labels in addition to 0 that is always
1147 // non-consuming. Precise behavior controlled by 'flags' argument. By
1148 // default, the underlying matcher is constructed by
1149 // MultiEpsMatcher. The user can instead pass in this object; in that
1150 // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
1151 // true.
1152 template <class M>
1153 class MultiEpsMatcher {
1154  public:
1155   using FST = typename M::FST;
1156   using Arc = typename FST::Arc;
1157   using Label = typename Arc::Label;
1158   using StateId = typename Arc::StateId;
1159   using Weight = typename Arc::Weight;
1160
1161   MultiEpsMatcher(const FST &fst, MatchType match_type,
1162                   uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1163                   M *matcher = nullptr, bool own_matcher = true)
1164       : matcher_(matcher ? matcher : new M(fst, match_type)),
1165         flags_(flags),
1166         own_matcher_(matcher ? own_matcher : true) {
1167     if (match_type == MATCH_INPUT) {
1168       loop_.ilabel = kNoLabel;
1169       loop_.olabel = 0;
1170     } else {
1171       loop_.ilabel = 0;
1172       loop_.olabel = kNoLabel;
1173     }
1174     loop_.weight = Weight::One();
1175     loop_.nextstate = kNoStateId;
1176   }
1177
1178   MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
1179       : matcher_(new M(*matcher.matcher_, safe)),
1180         flags_(matcher.flags_),
1181         own_matcher_(true),
1182         multi_eps_labels_(matcher.multi_eps_labels_),
1183         loop_(matcher.loop_) {
1184     loop_.nextstate = kNoStateId;
1185   }
1186
1187   ~MultiEpsMatcher() {
1188     if (own_matcher_) delete matcher_;
1189   }
1190
1191   MultiEpsMatcher<M> *Copy(bool safe = false) const {
1192     return new MultiEpsMatcher<M>(*this, safe);
1193   }
1194
1195   MatchType Type(bool test) const { return matcher_->Type(test); }
1196
1197   void SetState(StateId state) {
1198     matcher_->SetState(state);
1199     loop_.nextstate = state;
1200   }
1201
1202   bool Find(Label label);
1203
1204   bool Done() const { return done_; }
1205
1206   const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); }
1207
1208   void Next() {
1209     if (!current_loop_) {
1210       matcher_->Next();
1211       done_ = matcher_->Done();
1212       if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1213         ++multi_eps_iter_;
1214         while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1215                !matcher_->Find(*multi_eps_iter_)) {
1216           ++multi_eps_iter_;
1217         }
1218         if (multi_eps_iter_ != multi_eps_labels_.End()) {
1219           done_ = false;
1220         } else {
1221           done_ = !matcher_->Find(kNoLabel);
1222         }
1223       }
1224     } else {
1225       done_ = true;
1226     }
1227   }
1228
1229   const FST &GetFst() const { return matcher_->GetFst(); }
1230
1231   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1232
1233   const M *GetMatcher() const { return matcher_; }
1234
1235   Weight Final(StateId s) const { return matcher_->Final(s); }
1236
1237   uint32 Flags() const { return matcher_->Flags(); }
1238
1239   ssize_t Priority(StateId s) { return matcher_->Priority(s); }
1240
1241   void AddMultiEpsLabel(Label label) {
1242     if (label == 0) {
1243       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1244     } else {
1245       multi_eps_labels_.Insert(label);
1246     }
1247   }
1248
1249   void RemoveMultiEpsLabel(Label label) {
1250     if (label == 0) {
1251       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1252     } else {
1253       multi_eps_labels_.Erase(label);
1254     }
1255   }
1256
1257   void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); }
1258
1259  private:
1260   M *matcher_;
1261   uint32 flags_;
1262   bool own_matcher_;  // Does this class delete the matcher?
1263
1264   // Multi-eps label set.
1265   CompactSet<Label, kNoLabel> multi_eps_labels_;
1266   typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1267
1268   bool current_loop_;  // Current arc is the implicit loop?
1269   mutable Arc loop_;   // For non-consuming symbols.
1270   bool done_;          // Matching done?
1271
1272   MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete;
1273 };
1274
1275 template <class M>
1276 inline bool MultiEpsMatcher<M>::Find(Label label) {
1277   multi_eps_iter_ = multi_eps_labels_.End();
1278   current_loop_ = false;
1279   bool ret;
1280   if (label == 0) {
1281     ret = matcher_->Find(0);
1282   } else if (label == kNoLabel) {
1283     if (flags_ & kMultiEpsList) {
1284       // Returns all non-consuming arcs (including epsilon).
1285       multi_eps_iter_ = multi_eps_labels_.Begin();
1286       while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1287              !matcher_->Find(*multi_eps_iter_)) {
1288         ++multi_eps_iter_;
1289       }
1290       if (multi_eps_iter_ != multi_eps_labels_.End()) {
1291         ret = true;
1292       } else {
1293         ret = matcher_->Find(kNoLabel);
1294       }
1295     } else {
1296       // Returns all epsilon arcs.
1297       ret = matcher_->Find(kNoLabel);
1298     }
1299   } else if ((flags_ & kMultiEpsLoop) &&
1300              multi_eps_labels_.Find(label) != multi_eps_labels_.End()) {
1301     // Returns implicit loop.
1302     current_loop_ = true;
1303     ret = true;
1304   } else {
1305     ret = matcher_->Find(label);
1306   }
1307   done_ = !ret;
1308   return ret;
1309 }
1310
1311 // This class discards any implicit matches (e.g., the implicit epsilon
1312 // self-loops in the SortedMatcher). Matchers are most often used in
1313 // composition/intersection where the implicit matches are needed
1314 // e.g. for epsilon processing. However, if a matcher is simply being
1315 // used to look-up explicit label matches, this class saves the user
1316 // from having to check for and discard the unwanted implicit matches
1317 // themselves.
1318 template <class M>
1319 class ExplicitMatcher : public MatcherBase<typename M::Arc> {
1320  public:
1321   using FST = typename M::FST;
1322   using Arc = typename FST::Arc;
1323   using Label = typename Arc::Label;
1324   using StateId = typename Arc::StateId;
1325   using Weight = typename Arc::Weight;
1326
1327   ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr)
1328       : matcher_(matcher ? matcher : new M(fst, match_type)),
1329         match_type_(match_type),
1330         error_(false) {}
1331
1332   ExplicitMatcher(const ExplicitMatcher<M> &matcher, bool safe = false)
1333       : matcher_(new M(*matcher.matcher_, safe)),
1334         match_type_(matcher.match_type_),
1335         error_(matcher.error_) {}
1336
1337   ExplicitMatcher<M> *Copy(bool safe = false) const override {
1338     return new ExplicitMatcher<M>(*this, safe);
1339   }
1340
1341   MatchType Type(bool test) const override { return matcher_->Type(test); }
1342
1343   void SetState(StateId s) final { matcher_->SetState(s); }
1344
1345   bool Find(Label label) final {
1346     matcher_->Find(label);
1347     CheckArc();
1348     return !Done();
1349   }
1350
1351   bool Done() const final { return matcher_->Done(); }
1352
1353   const Arc &Value() const final { return matcher_->Value(); }
1354
1355   void Next() final {
1356     matcher_->Next();
1357     CheckArc();
1358   }
1359
1360   Weight Final(StateId s) const final { return matcher_->Final(s); }
1361
1362   ssize_t Priority(StateId s) final { return matcher_->Priority(s); }
1363
1364   const FST &GetFst() const final { return matcher_->GetFst(); }
1365
1366   uint64 Properties(uint64 inprops) const override {
1367     return matcher_->Properties(inprops);
1368   }
1369
1370   const M *GetMatcher() const { return matcher_.get(); }
1371
1372   uint32 Flags() const override { return matcher_->Flags(); }
1373
1374  private:
1375   // Checks current arc if available and explicit. If not available, stops. If
1376   // not explicit, checks next ones.
1377   void CheckArc() {
1378     for (; !matcher_->Done(); matcher_->Next()) {
1379       const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel
1380                                                     : matcher_->Value().olabel;
1381       if (label != kNoLabel) return;
1382     }
1383   }
1384
1385   std::unique_ptr<M> matcher_;
1386   MatchType match_type_;  // Type of match requested.
1387   bool error_;            // Error encountered?
1388 };
1389
1390 // Generic matcher, templated on the FST definition.
1391 //
1392 // Here is a typical use:
1393 //
1394 //   Matcher<StdFst> matcher(fst, MATCH_INPUT);
1395 //   matcher.SetState(state);
1396 //   if (matcher.Find(label))
1397 //     for (; !matcher.Done(); matcher.Next()) {
1398 //       auto &arc = matcher.Value();
1399 //       ...
1400 //     }
1401 template <class F>
1402 class Matcher {
1403  public:
1404   using FST = F;
1405   using Arc = typename F::Arc;
1406   using Label = typename Arc::Label;
1407   using StateId = typename Arc::StateId;
1408   using Weight = typename Arc::Weight;
1409
1410   Matcher(const FST &fst, MatchType match_type) {
1411     base_.reset(fst.InitMatcher(match_type));
1412     if (!base_) base_.reset(new SortedMatcher<FST>(fst, match_type));
1413   }
1414
1415   Matcher(const Matcher<FST> &matcher, bool safe = false) {
1416     base_.reset(matcher.base_->Copy(safe));
1417   }
1418
1419   // Takes ownership of the provided matcher.
1420   explicit Matcher(MatcherBase<Arc> *base_matcher) {
1421     base_.reset(base_matcher);
1422   }
1423
1424   Matcher<FST> *Copy(bool safe = false) const {
1425     return new Matcher<FST>(*this, safe);
1426   }
1427
1428   MatchType Type(bool test) const { return base_->Type(test); }
1429
1430   void SetState(StateId s) { base_->SetState(s); }
1431
1432   bool Find(Label label) { return base_->Find(label); }
1433
1434   bool Done() const { return base_->Done(); }
1435
1436   const Arc &Value() const { return base_->Value(); }
1437
1438   void Next() { base_->Next(); }
1439
1440   const FST &GetFst() const {
1441     return static_cast<const FST &>(base_->GetFst());
1442   }
1443
1444   uint64 Properties(uint64 props) const { return base_->Properties(props); }
1445
1446   Weight Final(StateId s) const { return base_->Final(s); }
1447
1448   uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1449
1450   ssize_t Priority(StateId s) { return base_->Priority(s); }
1451
1452  private:
1453   std::unique_ptr<MatcherBase<Arc>> base_;
1454 };
1455
1456 }  // namespace fst
1457
1458 #endif  // FST_MATCHER_H_