b95fbb23ffca371c544f67f14e9a70386e4cb5b8
[platform/upstream/openfst.git] / src / include / fst / lookahead-matcher.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to add lookahead to FST matchers, useful for improving composition
5 // efficiency with certain inputs.
6
7 #ifndef FST_LOOKAHEAD_MATCHER_H_
8 #define FST_LOOKAHEAD_MATCHER_H_
9
10 #include <memory>
11 #include <utility>
12 #include <vector>
13
14 #include <fst/flags.h>
15 #include <fst/log.h>
16
17 #include <fst/add-on.h>
18 #include <fst/const-fst.h>
19 #include <fst/fst.h>
20 #include <fst/label-reachable.h>
21 #include <fst/matcher.h>
22
23
24 DECLARE_string(save_relabel_ipairs);
25 DECLARE_string(save_relabel_opairs);
26
27 namespace fst {
28
29 // Lookahead matches extend the matcher interface with following additional
30 // methods:
31 //
32 // template <class FST>
33 // class LookAheadMatcher {
34 //  public:
35 //   using Arc = typename FST::Arc;
36 //   using Label = typename Arc::Label;
37 //   using StateId = typename Arc::StateId;
38 //   using Weight = typename Arc::Weight;
39 //
40 //  // Required constructors.
41 //  LookAheadMatcher(const FST &fst, MatchType match_type);
42 //   // If safe=true, the copy is thread-safe (except the lookahead FST is
43 //   // preserved). See Fst<>::Cop() for further doc.
44 //  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
45 //
46 //  // Below are methods for looking ahead for a match to a label and more
47 //  // generally, to a rational set. Each returns false if there is definitely
48 //  // not a match and returns true if there possibly is a match.
49 //
50 //  // Optionally pre-specifies the lookahead FST that will be passed to
51 //  // LookAheadFst() for possible precomputation. If copy is true, then the FST
52 //  // argument is a copy of the FST used in the previous call to this method
53 //  // (to avoid unnecessary updates).
54 //  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override;
55 //
56 //  // Are there paths from a state in the lookahead FST that can be read from
57 //  // the curent matcher state?
58 //  bool LookAheadFst(const Fst<Arc> &fst, StateId s) override;
59 //
60 //  // Can the label be read from the current matcher state after possibly
61 //  // following epsilon transitions?
62 //  bool LookAheadLabel(Label label) const override;
63 //
64 //  // The following methods allow looking ahead for an arbitrary rational set
65 //  // of strings, specified by an FST and a state from which to begin the
66 //  // matching. If the lookahead FST is a transducer, this looks on the side
67 //  // different from the matcher's match_type (cf. composition).
68 //  // Is there is a single non-epsilon arc found in the lookahead FST that
69 //  // begins the path (after possibly following any epsilons) in the last call
70 //  // to LookAheadFst? If so, return true and copy it to the arc argument;
71 //  // otherwise, return false. Non-trivial implementations are useful for
72 //  // label-pushing in composition.
73 //  bool LookAheadPrefix(Arc *arc) override;
74 //
75 //  // Gives an estimate of the combined weight of the paths in the lookahead
76 //  // and matcher FSTs for the last call to LookAheadFst. Non-trivial
77 //  // implementations are useful for weight-pushing in composition.
78 //  Weight LookAheadWeight() const override;
79 // };
80
81 // Look-ahead flags.
82 // Matcher is a lookahead matcher when match_type is MATCH_INPUT.
83 constexpr uint32 kInputLookAheadMatcher = 0x00000010;
84
85 // Matcher is a lookahead matcher when match_type is MATCH_OUTPUT.
86 constexpr uint32 kOutputLookAheadMatcher = 0x00000020;
87
88 // Is a non-trivial implementation of LookAheadWeight() method defined and
89 // if so, should it be used?
90 constexpr uint32 kLookAheadWeight = 0x00000040;
91
92 // Is a non-trivial implementation of LookAheadPrefix() method defined and
93 // if so, should it be used?
94 constexpr uint32 kLookAheadPrefix = 0x00000080;
95
96 // Look-ahead of matcher FST non-epsilon arcs?
97 constexpr uint32 kLookAheadNonEpsilons = 0x00000100;
98
99 // Look-ahead of matcher FST epsilon arcs?
100 constexpr uint32 kLookAheadEpsilons = 0x00000200;
101
102 // Ignore epsilon paths for the lookahead prefix? This gives correct results in
103 // composition only with an appropriate composition filter since it depends on
104 // the filter blocking the ignored paths.
105 constexpr uint32 kLookAheadNonEpsilonPrefix = 0x00000400;
106
107 // For LabelLookAheadMatcher, save relabeling data to file?
108 constexpr uint32 kLookAheadKeepRelabelData = 0x00000800;
109
110 // Flags used for lookahead matchers.
111 constexpr uint32 kLookAheadFlags = 0x00000ff0;
112
113 // LookAhead Matcher interface, templated on the Arc definition; used
114 // for lookahead matcher specializations that are returned by the
115 // InitMatcher() Fst method.
116 template <class Arc>
117 class LookAheadMatcherBase : public MatcherBase<Arc> {
118  public:
119   using Label = typename Arc::Label;
120   using StateId = typename Arc::StateId;
121   using Weight = typename Arc::Weight;
122
123   virtual void InitLookAheadFst(const Fst<Arc> &, bool copy = false) = 0;
124   virtual bool LookAheadFst(const Fst<Arc> &, StateId) = 0;
125   virtual bool LookAheadLabel(Label) const = 0;
126
127   // Suggested concrete implementation of lookahead methods.
128
129   bool LookAheadPrefix(Arc *arc) const {
130     if (prefix_arc_.nextstate != kNoStateId) {
131       *arc = prefix_arc_;
132       return true;
133     } else {
134       return false;
135     }
136   }
137
138   Weight LookAheadWeight() const { return weight_; }
139
140  protected:
141   // Concrete implementations for lookahead helper methods.
142
143   void ClearLookAheadWeight() { weight_ = Weight::One(); }
144
145   void SetLookAheadWeight(Weight weight) { weight_ = std::move(weight); }
146
147   void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
148
149   void SetLookAheadPrefix(Arc arc) { prefix_arc_ = std::move(arc); }
150
151  private:
152   Arc prefix_arc_;
153   Weight weight_;
154 };
155
156 // Doesn't actually lookahead, just declares that the future looks good.
157 template <class M>
158 class TrivialLookAheadMatcher
159     : public LookAheadMatcherBase<typename M::FST::Arc> {
160  public:
161   using FST = typename M::FST;
162   using Arc = typename FST::Arc;
163   using Label = typename Arc::Label;
164   using StateId = typename Arc::StateId;
165   using Weight = typename Arc::Weight;
166
167   TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
168       : matcher_(fst, match_type) {}
169
170   TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
171                           bool safe = false)
172       : matcher_(lmatcher.matcher_, safe) {}
173
174   TrivialLookAheadMatcher<M> *Copy(bool safe = false) const override {
175     return new TrivialLookAheadMatcher<M>(*this, safe);
176   }
177
178   MatchType Type(bool test) const override { return matcher_.Type(test); }
179
180   void SetState(StateId s) final { return matcher_.SetState(s); }
181
182   bool Find(Label label) final { return matcher_.Find(label); }
183
184   bool Done() const final { return matcher_.Done(); }
185
186   const Arc &Value() const final { return matcher_.Value(); }
187
188   void Next() final { matcher_.Next(); }
189
190   Weight Final(StateId s) const final { return matcher_.Final(s); }
191
192   ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
193
194   const FST &GetFst() const override { return matcher_.GetFst(); }
195
196   uint64 Properties(uint64 props) const override {
197     return matcher_.Properties(props);
198   }
199
200   uint32 Flags() const override {
201     return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
202   }
203
204   // Lookahead methods (all trivial).
205
206   void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {}
207
208   bool LookAheadFst(const Fst<Arc> &, StateId) final { return true; }
209
210   bool LookAheadLabel(Label) const final { return true; }
211
212   bool LookAheadPrefix(Arc *) const { return false; }
213
214   Weight LookAheadWeight() const { return Weight::One(); }
215
216  private:
217   M matcher_;
218 };
219
220 // Look-ahead of one transition. Template argument flags accepts flags to
221 // control behavior.
222 template <class M,
223           uint32 flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
224                          kLookAheadWeight | kLookAheadPrefix>
225 class ArcLookAheadMatcher : public LookAheadMatcherBase<typename M::FST::Arc> {
226  public:
227   using FST = typename M::FST;
228   using Arc = typename FST::Arc;
229   using Label = typename Arc::Label;
230   using StateId = typename Arc::StateId;
231   using Weight = typename Arc::Weight;
232   using MatcherData = NullAddOn;
233
234   using LookAheadMatcherBase<Arc>::ClearLookAheadWeight;
235   using LookAheadMatcherBase<Arc>::LookAheadWeight;
236   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
237   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
238   using LookAheadMatcherBase<Arc>::LookAheadPrefix;
239   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
240
241   enum : uint32 { kFlags = flags };
242
243   ArcLookAheadMatcher(
244       const FST &fst, MatchType match_type,
245       std::shared_ptr<MatcherData> data = std::shared_ptr<MatcherData>())
246       : matcher_(fst, match_type),
247         fst_(matcher_.GetFst()),
248         lfst_(nullptr),
249         state_(kNoStateId) {}
250
251   ArcLookAheadMatcher(const ArcLookAheadMatcher<M, flags> &lmatcher,
252                       bool safe = false)
253       : matcher_(lmatcher.matcher_, safe),
254         fst_(matcher_.GetFst()),
255         lfst_(lmatcher.lfst_),
256         state_(kNoStateId) {}
257
258   // General matcher methods.
259   ArcLookAheadMatcher<M, flags> *Copy(bool safe = false) const override {
260     return new ArcLookAheadMatcher<M, flags>(*this, safe);
261   }
262
263   MatchType Type(bool test) const override { return matcher_.Type(test); }
264
265   void SetState(StateId s) final {
266     state_ = s;
267     matcher_.SetState(s);
268   }
269
270   bool Find(Label label) final { return matcher_.Find(label); }
271
272   bool Done() const final { return matcher_.Done(); }
273
274   const Arc &Value() const final { return matcher_.Value(); }
275
276   void Next() final { matcher_.Next(); }
277
278   Weight Final(StateId s) const final { return matcher_.Final(s); }
279
280   ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
281
282   const FST &GetFst() const override { return fst_; }
283
284   uint64 Properties(uint64 props) const override {
285     return matcher_.Properties(props);
286   }
287
288   uint32 Flags() const override {
289     return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
290            kFlags;
291   }
292
293   const MatcherData *GetData() const { return nullptr; }
294
295   std::shared_ptr<MatcherData> GetSharedData() const {
296     return std::shared_ptr<MatcherData>();
297   }
298
299   // Look-ahead methods.
300
301   void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
302     lfst_ = &fst;
303   }
304
305   // Checks if there is a matching (possibly super-final) transition
306   // at (state_, s).
307   bool LookAheadFst(const Fst<Arc> &, StateId) final;
308
309   bool LookAheadLabel(Label label) const final { return matcher_.Find(label); }
310
311  private:
312   mutable M matcher_;
313   const FST &fst_;        // Matcher FST.
314   const Fst<Arc> *lfst_;  // Look-ahead FST.
315   StateId state_;         // Matcher state.
316 };
317
318 template <class M, uint32 flags>
319 bool ArcLookAheadMatcher<M, flags>::LookAheadFst(const Fst<Arc> &fst,
320                                                  StateId s) {
321   if (&fst != lfst_) InitLookAheadFst(fst);
322   bool result = false;
323   ssize_t nprefix = 0;
324   if (kFlags & kLookAheadWeight) ClearLookAheadWeight();
325   if (kFlags & kLookAheadPrefix) ClearLookAheadPrefix();
326   if (fst_.Final(state_) != Weight::Zero() &&
327       lfst_->Final(s) != Weight::Zero()) {
328     if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
329     ++nprefix;
330     if (kFlags & kLookAheadWeight) {
331       SetLookAheadWeight(
332           Plus(LookAheadWeight(), Times(fst_.Final(state_), lfst_->Final(s))));
333     }
334     result = true;
335   }
336   if (matcher_.Find(kNoLabel)) {
337     if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
338     ++nprefix;
339     if (kFlags & kLookAheadWeight) {
340       for (; !matcher_.Done(); matcher_.Next()) {
341         SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
342       }
343     }
344     result = true;
345   }
346   for (ArcIterator<Fst<Arc>> aiter(*lfst_, s); !aiter.Done(); aiter.Next()) {
347     const auto &arc = aiter.Value();
348     Label label = kNoLabel;
349     switch (matcher_.Type(false)) {
350       case MATCH_INPUT:
351         label = arc.olabel;
352         break;
353       case MATCH_OUTPUT:
354         label = arc.ilabel;
355         break;
356       default:
357         FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: Bad match type";
358         return true;
359     }
360     if (label == 0) {
361       if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
362       if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
363       if (kFlags & kLookAheadWeight) {
364         SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
365       }
366       result = true;
367     } else if (matcher_.Find(label)) {
368       if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
369       for (; !matcher_.Done(); matcher_.Next()) {
370         ++nprefix;
371         if (kFlags & kLookAheadWeight) {
372           SetLookAheadWeight(Plus(LookAheadWeight(),
373                                   Times(arc.weight, matcher_.Value().weight)));
374         }
375         if ((kFlags & kLookAheadPrefix) && nprefix == 1)
376           SetLookAheadPrefix(arc);
377       }
378       result = true;
379     }
380   }
381   if (kFlags & kLookAheadPrefix) {
382     if (nprefix == 1) {
383       ClearLookAheadWeight();  // Avoids double counting.
384     } else {
385       ClearLookAheadPrefix();
386     }
387   }
388   return result;
389 }
390
391 // Template argument flags accepts flags to control behavior. It must include
392 // precisely one of kInputLookAheadMatcher or kOutputLookAheadMatcher.
393 template <class M,
394           uint32 flags = kLookAheadEpsilons | kLookAheadWeight |
395                          kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
396                          kLookAheadKeepRelabelData,
397           class Accumulator = DefaultAccumulator<typename M::Arc>,
398           class Reachable = LabelReachable<typename M::Arc, Accumulator>>
399 class LabelLookAheadMatcher
400     : public LookAheadMatcherBase<typename M::FST::Arc> {
401  public:
402   using FST = typename M::FST;
403   using Arc = typename FST::Arc;
404   using Label = typename Arc::Label;
405   using StateId = typename Arc::StateId;
406   using Weight = typename Arc::Weight;
407   using MatcherData = typename Reachable::Data;
408
409   using LookAheadMatcherBase<Arc>::ClearLookAheadWeight;
410   using LookAheadMatcherBase<Arc>::LookAheadWeight;
411   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
412   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
413   using LookAheadMatcherBase<Arc>::LookAheadPrefix;
414   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
415
416   enum : uint32 { kFlags = flags };
417
418   LabelLookAheadMatcher(
419       const FST &fst, MatchType match_type,
420       std::shared_ptr<MatcherData> data = std::shared_ptr<MatcherData>(),
421       Accumulator *accumulator = nullptr)
422       : matcher_(fst, match_type),
423         lfst_(nullptr),
424         state_(kNoStateId),
425         error_(false) {
426     if (!(kFlags & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
427       FSTERROR() << "LabelLookaheadMatcher: Bad matcher flags: " << kFlags;
428       error_ = true;
429     }
430     const bool reach_input = match_type == MATCH_INPUT;
431     if (data) {
432       if (reach_input == data->ReachInput()) {
433         label_reachable_.reset(new Reachable(data, accumulator));
434       }
435     } else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
436                (!reach_input && (kFlags & kOutputLookAheadMatcher))) {
437       label_reachable_.reset(new Reachable(fst, reach_input, accumulator,
438                                            kFlags & kLookAheadKeepRelabelData));
439     }
440   }
441
442   LabelLookAheadMatcher(
443       const LabelLookAheadMatcher<M, flags, Accumulator, Reachable> &lmatcher,
444       bool safe = false)
445       : matcher_(lmatcher.matcher_, safe),
446         lfst_(lmatcher.lfst_),
447         label_reachable_(lmatcher.label_reachable_
448                              ? new Reachable(*lmatcher.label_reachable_, safe)
449                              : nullptr),
450         state_(kNoStateId),
451         error_(lmatcher.error_) {}
452
453   LabelLookAheadMatcher<M, flags, Accumulator, Reachable> *Copy(
454       bool safe = false) const override {
455     return new LabelLookAheadMatcher<M, flags, Accumulator, Reachable>(*this,
456                                                                        safe);
457   }
458
459   MatchType Type(bool test) const override { return matcher_.Type(test); }
460
461   void SetState(StateId s) final {
462     if (state_ == s) return;
463     state_ = s;
464     match_set_state_ = false;
465     reach_set_state_ = false;
466   }
467
468   bool Find(Label label) final {
469     if (!match_set_state_) {
470       matcher_.SetState(state_);
471       match_set_state_ = true;
472     }
473     return matcher_.Find(label);
474   }
475
476   bool Done() const final { return matcher_.Done(); }
477
478   const Arc &Value() const final { return matcher_.Value(); }
479
480   void Next() final { matcher_.Next(); }
481
482   Weight Final(StateId s) const final { return matcher_.Final(s); }
483
484   ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
485
486   const FST &GetFst() const override { return matcher_.GetFst(); }
487
488   uint64 Properties(uint64 inprops) const override {
489     auto outprops = matcher_.Properties(inprops);
490     if (error_ || (label_reachable_ && label_reachable_->Error())) {
491       outprops |= kError;
492     }
493     return outprops;
494   }
495
496   uint32 Flags() const override {
497     if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
498       return matcher_.Flags() | kFlags | kInputLookAheadMatcher;
499     } else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
500       return matcher_.Flags() | kFlags | kOutputLookAheadMatcher;
501     } else {
502       return matcher_.Flags();
503     }
504   }
505
506   const MatcherData *GetData() const {
507     return label_reachable_ ? label_reachable_->GetData() : nullptr;
508   };
509
510   std::shared_ptr<MatcherData> GetSharedData() const {
511     return label_reachable_ ? label_reachable_->GetSharedData()
512                             : std::shared_ptr<MatcherData>();
513   }
514   // Checks if there is a matching (possibly super-final) transition at
515   // (state_, s).
516   template <class LFST>
517   bool LookAheadFst(const LFST &fst, StateId s);
518
519   // Required to make class concrete.
520   bool LookAheadFst(const Fst<Arc> &fst, StateId s) final {
521     return LookAheadFst<Fst<Arc>>(fst, s);
522   }
523
524   void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
525     lfst_ = &fst;
526     if (label_reachable_) {
527       const bool reach_input = Type(false) == MATCH_OUTPUT;
528       label_reachable_->ReachInit(fst, reach_input, copy);
529     }
530   }
531
532   template <class LFST>
533   void InitLookAheadFst(const LFST &fst, bool copy = false) {
534     lfst_ = static_cast<const Fst<Arc> *>(&fst);
535     if (label_reachable_) {
536       const bool reach_input = Type(false) == MATCH_OUTPUT;
537       label_reachable_->ReachInit(fst, reach_input, copy);
538     }
539   }
540
541   bool LookAheadLabel(Label label) const final {
542     if (label == 0) return true;
543     if (label_reachable_) {
544       if (!reach_set_state_) {
545         label_reachable_->SetState(state_);
546         reach_set_state_ = true;
547       }
548       return label_reachable_->Reach(label);
549     } else {
550       return true;
551     }
552   }
553
554  private:
555   mutable M matcher_;
556   const Fst<Arc> *lfst_;                        // Look-ahead FST.
557   std::unique_ptr<Reachable> label_reachable_;  // Label reachability info.
558   StateId state_;                               // Matcher state.
559   bool match_set_state_;                        // matcher_.SetState called?
560   mutable bool reach_set_state_;                // reachable_.SetState called?
561   bool error_;                                  // Error encountered?
562 };
563
564 template <class M, uint32 flags, class Accumulator, class Reachable>
565 template <class LFST>
566 inline bool LabelLookAheadMatcher<M, flags, Accumulator,
567                                   Reachable>::LookAheadFst(const LFST &fst,
568                                                            StateId s) {
569   if (static_cast<const Fst<Arc> *>(&fst) != lfst_) InitLookAheadFst(fst);
570   ClearLookAheadWeight();
571   ClearLookAheadPrefix();
572   if (!label_reachable_) return true;
573   label_reachable_->SetState(state_, s);
574   reach_set_state_ = true;
575   bool compute_weight = kFlags & kLookAheadWeight;
576   bool compute_prefix = kFlags & kLookAheadPrefix;
577   ArcIterator<LFST> aiter(fst, s);
578   aiter.SetFlags(kArcNoCache, kArcNoCache);  // Makes caching optional.
579   const bool reach_arc = label_reachable_->Reach(
580       &aiter, 0, internal::NumArcs(*lfst_, s), compute_weight);
581   const auto lfinal = internal::Final(*lfst_, s);
582   const bool reach_final =
583       lfinal != Weight::Zero() && label_reachable_->ReachFinal();
584   if (reach_arc) {
585     const auto begin = label_reachable_->ReachBegin();
586     const auto end = label_reachable_->ReachEnd();
587     if (compute_prefix && end - begin == 1 && !reach_final) {
588       aiter.Seek(begin);
589       SetLookAheadPrefix(aiter.Value());
590       compute_weight = false;
591     } else if (compute_weight) {
592       SetLookAheadWeight(label_reachable_->ReachWeight());
593     }
594   }
595   if (reach_final && compute_weight) {
596     SetLookAheadWeight(reach_arc ? Plus(LookAheadWeight(), lfinal) : lfinal);
597   }
598   return reach_arc || reach_final;
599 }
600
601 // Label-lookahead relabeling class.
602 template <class Arc, class Data = LabelReachableData<typename Arc::Label>>
603 class LabelLookAheadRelabeler {
604  public:
605   using Label = typename Arc::Label;
606   using Reachable = LabelReachable<Arc, DefaultAccumulator<Arc>, Data>;
607
608   // Relabels matcher FST (initialization function object).
609   template <typename Impl>
610   explicit LabelLookAheadRelabeler(std::shared_ptr<Impl> *impl);
611
612   // Relabels arbitrary FST. Class LFST should be a label-lookahead FST.
613   template <class LFST>
614   static void Relabel(MutableFst<Arc> *fst, const LFST &mfst,
615                       bool relabel_input) {
616     const auto *data = mfst.GetAddOn();
617     Reachable reachable(data->First() ? data->SharedFirst()
618                                       : data->SharedSecond());
619     reachable.Relabel(fst, relabel_input);
620   }
621
622   // Returns relabeling pairs (cf. relabel.h::Relabel()). Class LFST should be a
623   // label-lookahead FST. If avoid_collisions is true, extra pairs are added to
624   // ensure no collisions when relabeling automata that have labels unseen here.
625   template <class LFST>
626   static void RelabelPairs(const LFST &mfst,
627                            std::vector<std::pair<Label, Label>> *pairs,
628                            bool avoid_collisions = false) {
629     const auto *data = mfst.GetAddOn();
630     Reachable reachable(data->First() ? data->SharedFirst()
631                                       : data->SharedSecond());
632     reachable.RelabelPairs(pairs, avoid_collisions);
633   }
634 };
635
636 template <class Arc, class Data>
637 template <typename Impl>
638 inline LabelLookAheadRelabeler<Arc, Data>::LabelLookAheadRelabeler(
639     std::shared_ptr<Impl> *impl) {
640   Fst<Arc> &fst = (*impl)->GetFst();
641   auto data = (*impl)->GetSharedAddOn();
642   const auto name = (*impl)->Type();
643   const bool is_mutable = fst.Properties(kMutable, false);
644   std::unique_ptr<MutableFst<Arc>> mfst;
645   if (is_mutable) {
646     mfst.reset(static_cast<MutableFst<Arc> *>(&fst));
647   } else {
648     mfst.reset(new VectorFst<Arc>(fst));
649   }
650   if (data->First()) {  // reach_input.
651     Reachable reachable(data->SharedFirst());
652     reachable.Relabel(mfst.get(), true);
653     if (!FLAGS_save_relabel_ipairs.empty()) {
654       std::vector<std::pair<Label, Label>> pairs;
655       reachable.RelabelPairs(&pairs, true);
656       WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
657     }
658   } else {
659     Reachable reachable(data->SharedSecond());
660     reachable.Relabel(mfst.get(), false);
661     if (!FLAGS_save_relabel_opairs.empty()) {
662       std::vector<std::pair<Label, Label>> pairs;
663       reachable.RelabelPairs(&pairs, true);
664       WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
665     }
666   }
667   if (!is_mutable) {
668     *impl = std::make_shared<Impl>(*mfst, name);
669     (*impl)->SetAddOn(data);
670   }
671 }
672
673 // Generic lookahead matcher, templated on the FST definition (a wrapper around
674 // a pointer to specific one).
675 template <class F>
676 class LookAheadMatcher {
677  public:
678   using FST = F;
679   using Arc = typename FST::Arc;
680   using Label = typename Arc::Label;
681   using StateId = typename Arc::StateId;
682   using Weight = typename Arc::Weight;
683   using LBase = LookAheadMatcherBase<Arc>;
684
685   LookAheadMatcher(const FST &fst, MatchType match_type)
686       : base_(fst.InitMatcher(match_type)) {
687     if (!base_) base_.reset(new SortedMatcher<FST>(fst, match_type));
688     lookahead_ = false;
689   }
690
691   // Takes ownership of base.
692   explicit LookAheadMatcher(MatcherBase<Arc> *base)
693       : base_(base), lookahead_(false) {}
694
695   LookAheadMatcher(const LookAheadMatcher<FST> &matcher, bool safe = false)
696       : base_(matcher.base_->Copy(safe)) {
697     lookahead_ = matcher.lookahead_;
698   }
699
700   LookAheadMatcher<FST> *Copy(bool safe = false) const {
701     return new LookAheadMatcher<FST>(*this, safe);
702   }
703
704   MatchType Type(bool test) const { return base_->Type(test); }
705
706   void SetState(StateId s) { base_->SetState(s); }
707
708   bool Find(Label label) { return base_->Find(label); }
709
710   bool Done() const { return base_->Done(); }
711
712   const Arc &Value() const { return base_->Value(); }
713
714   void Next() { base_->Next(); }
715
716   Weight Final(StateId s) const { return base_->Final(s); }
717
718   ssize_t Priority(StateId s) { return base_->Priority(s); }
719
720   const FST &GetFst() const {
721     return static_cast<const FST &>(base_->GetFst());
722   }
723
724   uint64 Properties(uint64 props) const { return base_->Properties(props); }
725
726   uint32 Flags() const { return base_->Flags(); }
727
728   bool LookAheadLabel(Label label) const {
729     if (LookAheadCheck()) {
730       return static_cast<LBase *>(base_.get())->LookAheadLabel(label);
731     } else {
732       return true;
733     }
734   }
735
736   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
737     if (LookAheadCheck()) {
738       return static_cast<LBase *>(base_.get())->LookAheadFst(fst, s);
739     } else {
740       return true;
741     }
742   }
743
744   Weight LookAheadWeight() const {
745     if (LookAheadCheck()) {
746       return static_cast<LBase *>(base_.get())->LookAheadWeight();
747     } else {
748       return Weight::One();
749     }
750   }
751
752   bool LookAheadPrefix(Arc *arc) const {
753     if (LookAheadCheck()) {
754       return static_cast<LBase *>(base_.get())->LookAheadPrefix(arc);
755     } else {
756       return false;
757     }
758   }
759
760   void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
761     if (LookAheadCheck()) {
762       static_cast<LBase *>(base_.get())->InitLookAheadFst(fst, copy);
763     }
764   }
765
766  private:
767   bool LookAheadCheck() const {
768     if (!lookahead_) {
769       lookahead_ =
770           base_->Flags() & (kInputLookAheadMatcher | kOutputLookAheadMatcher);
771       if (!lookahead_) {
772         FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
773       }
774     }
775     return lookahead_;
776   }
777
778   std::unique_ptr<MatcherBase<Arc>> base_;
779   mutable bool lookahead_;
780
781   LookAheadMatcher &operator=(const LookAheadMatcher &) = delete;
782 };
783
784 }  // namespace fst
785
786 #endif  // FST_LOOKAHEAD_MATCHER_H_