Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / arc-map.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to map over/transform arcs e.g., change semirings or
5 // implement project/invert. Consider using when operation does
6 // not change the number of arcs (except possibly superfinal arcs).
7
8 #ifndef FST_LIB_ARC_MAP_H_
9 #define FST_LIB_ARC_MAP_H_
10
11 #include <string>
12 #include <unordered_map>
13 #include <utility>
14
15 #include <fst/log.h>
16
17 #include <fst/cache.h>
18 #include <fst/mutable-fst.h>
19
20
21 namespace fst {
22
23 // Determines how final weights are mapped.
24 enum MapFinalAction {
25   // A final weight is mapped into a final weight. An error is raised if this
26   // is not possible.
27   MAP_NO_SUPERFINAL,
28   // A final weight is mapped to an arc to the superfinal state when the result
29   // cannot be represented as a final weight. The superfinal state will be
30   // added only if it is needed.
31   MAP_ALLOW_SUPERFINAL,
32   // A final weight is mapped to an arc to the superfinal state unless the
33   // result can be represented as a final weight of weight Zero(). The
34   // superfinal state is always added (if the input is not the empty FST).
35   MAP_REQUIRE_SUPERFINAL
36 };
37
38 // Determines how symbol tables are mapped.
39 enum MapSymbolsAction {
40   // Symbols should be cleared in the result by the map.
41   MAP_CLEAR_SYMBOLS,
42   // Symbols should be copied from the input FST by the map.
43   MAP_COPY_SYMBOLS,
44   // Symbols should not be modified in the result by the map itself.
45   // (They may set by the mapper).
46   MAP_NOOP_SYMBOLS
47 };
48
49 // The ArcMapper interfaces defines how arcs and final weights are mapped.
50 // This is useful for implementing operations that do not change the number of
51 // arcs (expect possibly superfinal arcs).
52 //
53 // template <class A, class B>
54 // class ArcMapper {
55 //  public:
56 //   using FromArc = A;
57 //   using ToArc = B;
58 //
59 //   // Maps an arc type FromArc to arc type ToArc.
60 //   ToArc operator()(const FromArc &arc);
61 //
62 //   // Specifies final action the mapper requires (see above).
63 //   // The mapper will be passed final weights as arcs of the form
64 //   // Arc(0, 0, weight, kNoStateId).
65 //   MapFinalAction FinalAction() const;
66 //
67 //   // Specifies input symbol table action the mapper requires (see above).
68 //   MapSymbolsAction InputSymbolsAction() const;
69 //
70 //   // Specifies output symbol table action the mapper requires (see above).
71 //   MapSymbolsAction OutputSymbolsAction() const;
72 //
73 //   // This specifies the known properties of an FST mapped by this mapper. It
74 //   takes as argument the input FSTs's known properties.
75 //   uint64 Properties(uint64 props) const;
76 // };
77 //
78 // The ArcMap functions and classes below will use the FinalAction()
79 // method of the mapper to determine how to treat final weights, e.g., whether
80 // to add a superfinal state. They will use the Properties() method to set the
81 // result FST properties.
82 //
83 // We include a various map versions below. One dimension of variation is
84 // whether the mapping mutates its input, writes to a new result FST, or is an
85 // on-the-fly FST. Another dimension is how we pass the mapper. We allow passing
86 // the mapper by pointer for cases that we need to change the state of the
87 // user's mapper.  This is the case with the EncodeMapper, which is reused
88 // during decoding. We also include map versions that pass the mapper by value
89 // or const reference when this suffices.
90
91 // Maps an arc type A using a mapper function object C, passed
92 // by pointer.  This version modifies its Fst input.
93 template <class A, class C>
94 void ArcMap(MutableFst<A> *fst, C *mapper) {
95   using FromArc = A;
96   using ToArc = A;
97   using StateId = typename FromArc::StateId;
98   using Weight = typename FromArc::Weight;
99   if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
100     fst->SetInputSymbols(nullptr);
101   }
102   if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
103     fst->SetOutputSymbols(nullptr);
104   }
105   if (fst->Start() == kNoStateId) return;
106   const auto props = fst->Properties(kFstProperties, false);
107   const auto final_action = mapper->FinalAction();
108   auto superfinal = kNoStateId;
109   if (final_action == MAP_REQUIRE_SUPERFINAL) {
110     superfinal = fst->AddState();
111     fst->SetFinal(superfinal, Weight::One());
112   }
113   for (StateIterator<MutableFst<FromArc>> siter(*fst); !siter.Done();
114        siter.Next()) {
115     const auto state = siter.Value();
116     for (MutableArcIterator<MutableFst<FromArc>> aiter(fst, state);
117          !aiter.Done(); aiter.Next()) {
118       const auto &arc = aiter.Value();
119       aiter.SetValue((*mapper)(arc));
120     }
121     switch (final_action) {
122       case MAP_NO_SUPERFINAL:
123       default: {
124         const FromArc arc(0, 0, fst->Final(state), kNoStateId);
125         const auto final_arc = (*mapper)(arc);
126         if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
127           FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc";
128           fst->SetProperties(kError, kError);
129         }
130         fst->SetFinal(state, final_arc.weight);
131         break;
132       }
133       case MAP_ALLOW_SUPERFINAL: {
134         if (state != superfinal) {
135           const FromArc arc(0, 0, fst->Final(state), kNoStateId);
136           auto final_arc = (*mapper)(arc);
137           if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
138             // Add a superfinal state if not already done.
139             if (superfinal == kNoStateId) {
140               superfinal = fst->AddState();
141               fst->SetFinal(superfinal, Weight::One());
142             }
143             final_arc.nextstate = superfinal;
144             fst->AddArc(state, final_arc);
145             fst->SetFinal(state, Weight::Zero());
146           } else {
147             fst->SetFinal(state, final_arc.weight);
148           }
149         }
150         break;
151       }
152       case MAP_REQUIRE_SUPERFINAL: {
153         if (state != superfinal) {
154           const FromArc arc(0, 0, fst->Final(state), kNoStateId);
155           const auto final_arc = (*mapper)(arc);
156           if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
157               final_arc.weight != Weight::Zero()) {
158             fst->AddArc(state, ToArc(final_arc.ilabel, final_arc.olabel,
159                                      final_arc.weight, superfinal));
160           }
161           fst->SetFinal(state, Weight::Zero());
162         }
163         break;
164       }
165     }
166   }
167   fst->SetProperties(mapper->Properties(props), kFstProperties);
168 }
169
170 // Maps an arc type A using a mapper function object C, passed by value. This
171 // version modifies its FST input.
172 template <class A, class C>
173 void ArcMap(MutableFst<A> *fst, C mapper) {
174   ArcMap(fst, &mapper);
175 }
176
177 // Maps an arc type A to an arc type B using mapper function object C,
178 // passed by pointer. This version writes the mapped input FST to an
179 // output MutableFst.
180 template <class A, class B, class C>
181 void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C *mapper) {
182   using FromArc = A;
183   using StateId = typename FromArc::StateId;
184   using Weight = typename FromArc::Weight;
185   ofst->DeleteStates();
186   if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
187     ofst->SetInputSymbols(ifst.InputSymbols());
188   } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
189     ofst->SetInputSymbols(nullptr);
190   }
191   if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
192     ofst->SetOutputSymbols(ifst.OutputSymbols());
193   } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
194     ofst->SetOutputSymbols(nullptr);
195   }
196   const auto iprops = ifst.Properties(kCopyProperties, false);
197   if (ifst.Start() == kNoStateId) {
198     if (iprops & kError) ofst->SetProperties(kError, kError);
199     return;
200   }
201   const auto final_action = mapper->FinalAction();
202   if (ifst.Properties(kExpanded, false)) {
203     ofst->ReserveStates(
204         CountStates(ifst) + final_action == MAP_NO_SUPERFINAL ? 0 : 1);
205   }
206   // Adds all states.
207   for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
208     ofst->AddState();
209   }
210   StateId superfinal = kNoStateId;
211   if (final_action == MAP_REQUIRE_SUPERFINAL) {
212     superfinal = ofst->AddState();
213     ofst->SetFinal(superfinal, B::Weight::One());
214   }
215   for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
216     StateId s = siter.Value();
217     if (s == ifst.Start()) ofst->SetStart(s);
218     ofst->ReserveArcs(s, ifst.NumArcs(s));
219     for (ArcIterator<Fst<A>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
220       ofst->AddArc(s, (*mapper)(aiter.Value()));
221     }
222     switch (final_action) {
223       case MAP_NO_SUPERFINAL:
224       default: {
225         B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
226         if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
227           FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc";
228           ofst->SetProperties(kError, kError);
229         }
230         ofst->SetFinal(s, final_arc.weight);
231         break;
232       }
233       case MAP_ALLOW_SUPERFINAL: {
234         B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
235         if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
236           // Add a superfinal state if not already done.
237           if (superfinal == kNoStateId) {
238             superfinal = ofst->AddState();
239             ofst->SetFinal(superfinal, B::Weight::One());
240           }
241           final_arc.nextstate = superfinal;
242           ofst->AddArc(s, final_arc);
243           ofst->SetFinal(s, B::Weight::Zero());
244         } else {
245           ofst->SetFinal(s, final_arc.weight);
246         }
247         break;
248       }
249       case MAP_REQUIRE_SUPERFINAL: {
250         B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
251         if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
252             final_arc.weight != B::Weight::Zero()) {
253           ofst->AddArc(s, B(final_arc.ilabel, final_arc.olabel,
254                             final_arc.weight, superfinal));
255         }
256         ofst->SetFinal(s, B::Weight::Zero());
257         break;
258       }
259     }
260   }
261   const auto oprops = ofst->Properties(kFstProperties, false);
262   ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
263 }
264
265 // Maps an arc type A to an arc type B using mapper function
266 // object C, passed by value. This version writes the mapped input
267 // Fst to an output MutableFst.
268 template <class A, class B, class C>
269 void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
270   ArcMap(ifst, ofst, &mapper);
271 }
272
273 struct ArcMapFstOptions : public CacheOptions {
274   // ArcMapFst default caching behaviour is to do no caching. Most mappers are
275   // cheap and therefore we save memory by not doing caching.
276   ArcMapFstOptions() : CacheOptions(true, 0) {}
277
278   explicit ArcMapFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
279 };
280
281 template <class A, class B, class C>
282 class ArcMapFst;
283
284 namespace internal {
285
286 // Implementation of delayed ArcMapFst.
287 template <class A, class B, class C>
288 class ArcMapFstImpl : public CacheImpl<B> {
289  public:
290   using Arc = B;
291   using StateId = typename Arc::StateId;
292   using Weight = typename Arc::Weight;
293
294   using FstImpl<B>::SetType;
295   using FstImpl<B>::SetProperties;
296   using FstImpl<B>::SetInputSymbols;
297   using FstImpl<B>::SetOutputSymbols;
298
299   using CacheImpl<B>::PushArc;
300   using CacheImpl<B>::HasArcs;
301   using CacheImpl<B>::HasFinal;
302   using CacheImpl<B>::HasStart;
303   using CacheImpl<B>::SetArcs;
304   using CacheImpl<B>::SetFinal;
305   using CacheImpl<B>::SetStart;
306
307   friend class StateIterator<ArcMapFst<A, B, C>>;
308
309   ArcMapFstImpl(const Fst<A> &fst, const C &mapper,
310                 const ArcMapFstOptions &opts)
311       : CacheImpl<B>(opts),
312         fst_(fst.Copy()),
313         mapper_(new C(mapper)),
314         own_mapper_(true),
315         superfinal_(kNoStateId),
316         nstates_(0) {
317     Init();
318   }
319
320   ArcMapFstImpl(const Fst<A> &fst, C *mapper, const ArcMapFstOptions &opts)
321       : CacheImpl<B>(opts),
322         fst_(fst.Copy()),
323         mapper_(mapper),
324         own_mapper_(false),
325         superfinal_(kNoStateId),
326         nstates_(0) {
327     Init();
328   }
329
330   ArcMapFstImpl(const ArcMapFstImpl<A, B, C> &impl)
331       : CacheImpl<B>(impl),
332         fst_(impl.fst_->Copy(true)),
333         mapper_(new C(*impl.mapper_)),
334         own_mapper_(true),
335         superfinal_(kNoStateId),
336         nstates_(0) {
337     Init();
338   }
339
340   ~ArcMapFstImpl() override {
341     if (own_mapper_) delete mapper_;
342   }
343
344   StateId Start() {
345     if (!HasStart()) SetStart(FindOState(fst_->Start()));
346     return CacheImpl<B>::Start();
347   }
348
349   Weight Final(StateId s) {
350     if (!HasFinal(s)) {
351       switch (final_action_) {
352         case MAP_NO_SUPERFINAL:
353         default: {
354           const auto final_arc =
355               (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
356           if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
357             FSTERROR() << "ArcMapFst: Non-zero arc labels for superfinal arc";
358             SetProperties(kError, kError);
359           }
360           SetFinal(s, final_arc.weight);
361           break;
362         }
363         case MAP_ALLOW_SUPERFINAL: {
364           if (s == superfinal_) {
365             SetFinal(s, Weight::One());
366           } else {
367             const auto final_arc =
368                 (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
369             if (final_arc.ilabel == 0 && final_arc.olabel == 0) {
370               SetFinal(s, final_arc.weight);
371             } else {
372               SetFinal(s, Weight::Zero());
373             }
374           }
375           break;
376         }
377         case MAP_REQUIRE_SUPERFINAL: {
378           SetFinal(s, s == superfinal_ ? Weight::One() : Weight::Zero());
379           break;
380         }
381       }
382     }
383     return CacheImpl<B>::Final(s);
384   }
385
386   size_t NumArcs(StateId s) {
387     if (!HasArcs(s)) Expand(s);
388     return CacheImpl<B>::NumArcs(s);
389   }
390
391   size_t NumInputEpsilons(StateId s) {
392     if (!HasArcs(s)) Expand(s);
393     return CacheImpl<B>::NumInputEpsilons(s);
394   }
395
396   size_t NumOutputEpsilons(StateId s) {
397     if (!HasArcs(s)) Expand(s);
398     return CacheImpl<B>::NumOutputEpsilons(s);
399   }
400
401   uint64 Properties() const override { return Properties(kFstProperties); }
402
403   // Sets error if found, and returns other FST impl properties.
404   uint64 Properties(uint64 mask) const override {
405     if ((mask & kError) && (fst_->Properties(kError, false) ||
406                             (mapper_->Properties(0) & kError))) {
407       SetProperties(kError, kError);
408     }
409     return FstImpl<Arc>::Properties(mask);
410   }
411
412   void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
413     if (!HasArcs(s)) Expand(s);
414     CacheImpl<B>::InitArcIterator(s, data);
415   }
416
417   void Expand(StateId s) {
418     // Add exiting arcs.
419     if (s == superfinal_) {
420       SetArcs(s);
421       return;
422     }
423     for (ArcIterator<Fst<A>> aiter(*fst_, FindIState(s)); !aiter.Done();
424          aiter.Next()) {
425       auto aarc = aiter.Value();
426       aarc.nextstate = FindOState(aarc.nextstate);
427       const auto &barc = (*mapper_)(aarc);
428       PushArc(s, barc);
429     }
430
431     // Check for superfinal arcs.
432     if (!HasFinal(s) || Final(s) == Weight::Zero()) {
433       switch (final_action_) {
434         case MAP_NO_SUPERFINAL:
435         default:
436           break;
437         case MAP_ALLOW_SUPERFINAL: {
438           auto final_arc =
439               (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
440           if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
441             if (superfinal_ == kNoStateId) superfinal_ = nstates_++;
442             final_arc.nextstate = superfinal_;
443             PushArc(s, final_arc);
444           }
445           break;
446         }
447         case MAP_REQUIRE_SUPERFINAL: {
448           const auto final_arc =
449               (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
450           if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
451               final_arc.weight != B::Weight::Zero()) {
452             PushArc(s, B(final_arc.ilabel, final_arc.olabel, final_arc.weight,
453                          superfinal_));
454           }
455           break;
456         }
457       }
458     }
459     SetArcs(s);
460   }
461
462  private:
463   void Init() {
464     SetType("map");
465     if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
466       SetInputSymbols(fst_->InputSymbols());
467     } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
468       SetInputSymbols(nullptr);
469     }
470     if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
471       SetOutputSymbols(fst_->OutputSymbols());
472     } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
473       SetOutputSymbols(nullptr);
474     }
475     if (fst_->Start() == kNoStateId) {
476       final_action_ = MAP_NO_SUPERFINAL;
477       SetProperties(kNullProperties);
478     } else {
479       final_action_ = mapper_->FinalAction();
480       uint64 props = fst_->Properties(kCopyProperties, false);
481       SetProperties(mapper_->Properties(props));
482       if (final_action_ == MAP_REQUIRE_SUPERFINAL) superfinal_ = 0;
483     }
484   }
485
486   // Maps from output state to input state.
487   StateId FindIState(StateId s) {
488     if (superfinal_ == kNoStateId || s < superfinal_) {
489       return s;
490     } else {
491       return s - 1;
492     }
493   }
494
495   // Maps from input state to output state.
496   StateId FindOState(StateId is) {
497     auto os = is;
498     if (!(superfinal_ == kNoStateId || is < superfinal_)) ++os;
499     if (os >= nstates_) nstates_ = os + 1;
500     return os;
501   }
502
503   std::unique_ptr<const Fst<A>> fst_;
504   C *mapper_;
505   const bool own_mapper_;
506   MapFinalAction final_action_;
507   StateId superfinal_;
508   StateId nstates_;
509 };
510
511 }  // namespace internal
512
513 // Maps an arc type A to an arc type B using Mapper function object
514 // C. This version is a delayed FST.
515 template <class A, class B, class C>
516 class ArcMapFst : public ImplToFst<internal::ArcMapFstImpl<A, B, C>> {
517  public:
518   using Arc = B;
519   using StateId = typename Arc::StateId;
520   using Weight = typename Arc::Weight;
521
522   using Store = DefaultCacheStore<B>;
523   using State = typename Store::State;
524   using Impl = internal::ArcMapFstImpl<A, B, C>;
525
526   friend class ArcIterator<ArcMapFst<A, B, C>>;
527   friend class StateIterator<ArcMapFst<A, B, C>>;
528
529   ArcMapFst(const Fst<A> &fst, const C &mapper, const ArcMapFstOptions &opts)
530       : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
531
532   ArcMapFst(const Fst<A> &fst, C *mapper, const ArcMapFstOptions &opts)
533       : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
534
535   ArcMapFst(const Fst<A> &fst, const C &mapper)
536       : ImplToFst<Impl>(
537             std::make_shared<Impl>(fst, mapper, ArcMapFstOptions())) {}
538
539   ArcMapFst(const Fst<A> &fst, C *mapper)
540       : ImplToFst<Impl>(
541             std::make_shared<Impl>(fst, mapper, ArcMapFstOptions())) {}
542
543   // See Fst<>::Copy() for doc.
544   ArcMapFst(const ArcMapFst<A, B, C> &fst, bool safe = false)
545       : ImplToFst<Impl>(fst, safe) {}
546
547   // Get a copy of this ArcMapFst. See Fst<>::Copy() for further doc.
548   ArcMapFst<A, B, C> *Copy(bool safe = false) const override {
549     return new ArcMapFst<A, B, C>(*this, safe);
550   }
551
552   inline void InitStateIterator(StateIteratorData<B> *data) const override;
553
554   void InitArcIterator(StateId s, ArcIteratorData<B> *data) const override {
555     GetMutableImpl()->InitArcIterator(s, data);
556   }
557
558  protected:
559   using ImplToFst<Impl>::GetImpl;
560   using ImplToFst<Impl>::GetMutableImpl;
561
562  private:
563   ArcMapFst &operator=(const ArcMapFst &) = delete;
564 };
565
566 // Specialization for ArcMapFst.
567 //
568 // This may be derived from.
569 template <class A, class B, class C>
570 class StateIterator<ArcMapFst<A, B, C>> : public StateIteratorBase<B> {
571  public:
572   using StateId = typename B::StateId;
573
574   explicit StateIterator(const ArcMapFst<A, B, C> &fst)
575       : impl_(fst.GetImpl()),
576         siter_(*impl_->fst_),
577         s_(0),
578         superfinal_(impl_->final_action_ == MAP_REQUIRE_SUPERFINAL) {
579     CheckSuperfinal();
580   }
581
582   bool Done() const final { return siter_.Done() && !superfinal_; }
583
584   StateId Value() const final { return s_; }
585
586   void Next() final {
587     ++s_;
588     if (!siter_.Done()) {
589       siter_.Next();
590       CheckSuperfinal();
591     } else if (superfinal_) {
592       superfinal_ = false;
593     }
594   }
595
596   void Reset() final {
597     s_ = 0;
598     siter_.Reset();
599     superfinal_ = impl_->final_action_ == MAP_REQUIRE_SUPERFINAL;
600     CheckSuperfinal();
601   }
602
603  private:
604   void CheckSuperfinal() {
605     if (impl_->final_action_ != MAP_ALLOW_SUPERFINAL || superfinal_) return;
606     if (!siter_.Done()) {
607       const auto final_arc =
608           (*impl_->mapper_)(A(0, 0, impl_->fst_->Final(s_), kNoStateId));
609       if (final_arc.ilabel != 0 || final_arc.olabel != 0) superfinal_ = true;
610     }
611   }
612
613   const internal::ArcMapFstImpl<A, B, C> *impl_;
614   StateIterator<Fst<A>> siter_;
615   StateId s_;
616   bool superfinal_;  // True if there is a superfinal state and not done.
617 };
618
619 // Specialization for ArcMapFst.
620 template <class A, class B, class C>
621 class ArcIterator<ArcMapFst<A, B, C>>
622     : public CacheArcIterator<ArcMapFst<A, B, C>> {
623  public:
624   using StateId = typename A::StateId;
625
626   ArcIterator(const ArcMapFst<A, B, C> &fst, StateId s)
627       : CacheArcIterator<ArcMapFst<A, B, C>>(fst.GetMutableImpl(), s) {
628     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
629   }
630 };
631
632 template <class A, class B, class C>
633 inline void ArcMapFst<A, B, C>::InitStateIterator(
634     StateIteratorData<B> *data) const {
635   data->base = new StateIterator<ArcMapFst<A, B, C>>(*this);
636 }
637
638 // Utility Mappers.
639
640 // Mapper that returns its input.
641 template <class A>
642 class IdentityArcMapper {
643  public:
644   using FromArc = A;
645   using ToArc = A;
646
647   ToArc operator()(const FromArc &arc) const { return arc; }
648
649   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
650
651   constexpr MapSymbolsAction InputSymbolsAction() const {
652     return MAP_COPY_SYMBOLS;
653   }
654
655   constexpr MapSymbolsAction OutputSymbolsAction() const {
656     return MAP_COPY_SYMBOLS;
657   }
658
659   uint64 Properties(uint64 props) const { return props; }
660 };
661
662 // Mapper that converts all input symbols to epsilon.
663 template <class A>
664 class InputEpsilonMapper {
665  public:
666   using FromArc = A;
667   using ToArc = A;
668
669   ToArc operator()(const FromArc &arc) const {
670     return ToArc(0, arc.olabel, arc.weight, arc.nextstate);
671   }
672
673   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
674
675   constexpr MapSymbolsAction InputSymbolsAction() const {
676     return MAP_CLEAR_SYMBOLS;
677   }
678
679   constexpr MapSymbolsAction OutputSymbolsAction() const {
680     return MAP_COPY_SYMBOLS;
681   }
682
683   uint64 Properties(uint64 props) const {
684     return (props & kSetArcProperties) | kIEpsilons;
685   }
686 };
687
688 // Mapper that converts all output symbols to epsilon.
689 template <class A>
690 class OutputEpsilonMapper {
691  public:
692   using FromArc = A;
693   using ToArc = A;
694
695   ToArc operator()(const FromArc &arc) const {
696     return ToArc(arc.ilabel, 0, arc.weight, arc.nextstate);
697   }
698
699   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
700
701   constexpr MapSymbolsAction InputSymbolsAction() const {
702     return MAP_COPY_SYMBOLS;
703   }
704
705   constexpr MapSymbolsAction OutputSymbolsAction() const {
706     return MAP_CLEAR_SYMBOLS;
707   }
708
709   uint64 Properties(uint64 props) const {
710     return (props & kSetArcProperties) | kOEpsilons;
711   }
712 };
713
714 // Mapper that returns its input with final states redirected to a single
715 // super-final state.
716 template <class A>
717 class SuperFinalMapper {
718  public:
719   using FromArc = A;
720   using ToArc = A;
721   using Label = typename FromArc::Label;
722   using Weight = typename FromArc::Weight;;
723
724   // Arg allows setting super-final label.
725   explicit SuperFinalMapper(Label final_label = 0)
726       : final_label_(final_label) {}
727
728   ToArc operator()(const FromArc &arc) const {
729     // Super-final arc.
730     if (arc.nextstate == kNoStateId && arc.weight != Weight::Zero()) {
731       return ToArc(final_label_, final_label_, arc.weight, kNoStateId);
732     } else {
733       return arc;
734     }
735   }
736
737   constexpr MapFinalAction FinalAction() const {
738     return MAP_REQUIRE_SUPERFINAL;
739   }
740
741   constexpr MapSymbolsAction InputSymbolsAction() const {
742     return MAP_COPY_SYMBOLS;
743   }
744
745   constexpr MapSymbolsAction OutputSymbolsAction() const {
746     return MAP_COPY_SYMBOLS;
747   }
748
749   uint64 Properties(uint64 props) const {
750     if (final_label_ == 0) {
751       return props & kAddSuperFinalProperties;
752     } else {
753       return props & kAddSuperFinalProperties &
754           kILabelInvariantProperties & kOLabelInvariantProperties;
755     }
756   }
757
758  private:
759   Label final_label_;
760 };
761
762 // Mapper that leaves labels and nextstate unchanged and constructs a new weight
763 // from the underlying value of the arc weight. If no weight converter is
764 // explictly specified, requires that there is a WeightConvert class
765 // specialization that converts the weights.
766 template <class A, class B,
767           class C = WeightConvert<typename A::Weight, typename B::Weight>>
768 class WeightConvertMapper {
769  public:
770   using FromArc = A;
771   using ToArc = B;
772   using Converter = C;
773   using FromWeight = typename FromArc::Weight;
774   using ToWeight = typename ToArc::Weight;
775
776   explicit WeightConvertMapper(const Converter &c = Converter())
777       : convert_weight_(c) {}
778
779   ToArc operator()(const FromArc &arc) const {
780     return ToArc(arc.ilabel, arc.olabel, convert_weight_(arc.weight),
781                  arc.nextstate);
782   }
783
784   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
785
786   constexpr MapSymbolsAction InputSymbolsAction() const {
787     return MAP_COPY_SYMBOLS;
788   }
789
790   constexpr MapSymbolsAction OutputSymbolsAction() const {
791     return MAP_COPY_SYMBOLS;
792   }
793
794   uint64 Properties(uint64 props) const { return props; }
795
796  private:
797   Converter convert_weight_;
798 };
799
800 // Non-precision-changing weight conversions; consider using more efficient
801 // Cast method instead.
802
803 using StdToLogMapper = WeightConvertMapper<StdArc, LogArc>;
804
805 using LogToStdMapper = WeightConvertMapper<LogArc, StdArc>;
806
807 // Precision-changing weight conversions.
808
809 using StdToLog64Mapper = WeightConvertMapper<StdArc, Log64Arc>;
810
811 using LogToLog64Mapper = WeightConvertMapper<LogArc, Log64Arc>;
812
813 using Log64ToStdMapper = WeightConvertMapper<Log64Arc, StdArc>;
814
815 using Log64ToLogMapper = WeightConvertMapper<Log64Arc, LogArc>;
816
817 // Mapper from A to GallicArc<A>.
818 template <class A, GallicType G = GALLIC_LEFT>
819 class ToGallicMapper {
820  public:
821   using FromArc = A;
822   using ToArc = GallicArc<A, G>;
823
824   using SW = StringWeight<typename A::Label, GallicStringType(G)>;
825   using AW = typename FromArc::Weight;
826   using GW = typename ToArc::Weight;
827
828   ToArc operator()(const FromArc &arc) const {
829     // Super-final arc.
830     if (arc.nextstate == kNoStateId && arc.weight != AW::Zero()) {
831       return ToArc(0, 0, GW(SW::One(), arc.weight), kNoStateId);
832     // Super-non-final arc.
833     } else if (arc.nextstate == kNoStateId) {
834       return ToArc(0, 0, GW::Zero(), kNoStateId);
835     // Epsilon label.
836     } else if (arc.olabel == 0) {
837       return ToArc(arc.ilabel, arc.ilabel, GW(SW::One(), arc.weight),
838                    arc.nextstate);
839     // Regular label.
840     } else {
841       return ToArc(arc.ilabel, arc.ilabel, GW(SW(arc.olabel), arc.weight),
842                    arc.nextstate);
843     }
844   }
845
846   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
847
848   constexpr MapSymbolsAction InputSymbolsAction() const {
849     return MAP_COPY_SYMBOLS;
850   }
851
852   constexpr MapSymbolsAction OutputSymbolsAction() const {
853     return MAP_CLEAR_SYMBOLS;
854   }
855
856   uint64 Properties(uint64 props) const {
857     return ProjectProperties(props, true) & kWeightInvariantProperties;
858   }
859 };
860
861 // Mapper from GallicArc<A> to A.
862 template <class A, GallicType G = GALLIC_LEFT>
863 class FromGallicMapper {
864  public:
865   using FromArc = GallicArc<A, G>;
866   using ToArc = A;
867
868   using Label = typename ToArc::Label;
869   using AW = typename ToArc::Weight;
870   using GW = typename FromArc::Weight;
871
872   explicit FromGallicMapper(Label superfinal_label = 0)
873       : superfinal_label_(superfinal_label), error_(false) {}
874
875   ToArc operator()(const FromArc &arc) const {
876     // 'Super-non-final' arc.
877     if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) {
878       return A(arc.ilabel, 0, AW::Zero(), kNoStateId);
879     }
880     Label l = kNoLabel;
881     AW weight;
882     if (!Extract(arc.weight, &weight, &l) || arc.ilabel != arc.olabel) {
883       FSTERROR() << "FromGallicMapper: Unrepresentable weight: " << arc.weight
884                  << " for arc with ilabel = " << arc.ilabel
885                  << ", olabel = " << arc.olabel
886                  << ", nextstate = " << arc.nextstate;
887       error_ = true;
888     }
889     if (arc.ilabel == 0 && l != 0 && arc.nextstate == kNoStateId) {
890       return ToArc(superfinal_label_, l, weight, arc.nextstate);
891     } else {
892       return ToArc(arc.ilabel, l, weight, arc.nextstate);
893     }
894   }
895
896   constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }
897
898   constexpr MapSymbolsAction InputSymbolsAction() const {
899     return MAP_COPY_SYMBOLS;
900   }
901
902   constexpr MapSymbolsAction OutputSymbolsAction() const {
903     return MAP_CLEAR_SYMBOLS;
904   }
905
906   uint64 Properties(uint64 inprops) const {
907     uint64 outprops = inprops & kOLabelInvariantProperties &
908                       kWeightInvariantProperties & kAddSuperFinalProperties;
909     if (error_) outprops |= kError;
910     return outprops;
911   }
912
913  private:
914   template <GallicType GT>
915   static bool Extract(const GallicWeight<Label, AW, GT> &gallic_weight,
916                       typename A::Weight *weight, typename A::Label *label) {
917     using GW = StringWeight<Label, GallicStringType(GT)>;
918     const GW &w1 = gallic_weight.Value1();
919     const AW &w2 = gallic_weight.Value2();
920     typename GW::Iterator iter1(w1);
921     const Label l = w1.Size() == 1 ? iter1.Value() : 0;
922     if (l == kStringInfinity || l == kStringBad || w1.Size() > 1) return false;
923     *label = l;
924     *weight = w2;
925     return true;
926   }
927
928   static bool Extract(const GallicWeight<Label, AW, GALLIC> &gallic_weight,
929                       typename A::Weight *weight, typename A::Label *label) {
930     if (gallic_weight.Size() > 1) return false;
931     if (gallic_weight.Size() == 0) {
932       *label = 0;
933       *weight = A::Weight::Zero();
934       return true;
935     }
936     return Extract<GALLIC_RESTRICT>(gallic_weight.Back(), weight, label);
937   }
938
939   const Label superfinal_label_;
940   mutable bool error_;
941 };
942
943 // Mapper from GallicArc<A> to A.
944 template <class A, GallicType G = GALLIC_LEFT>
945 class GallicToNewSymbolsMapper {
946  public:
947   using FromArc = GallicArc<A, G>;
948   using ToArc = A;
949
950   using Label = typename ToArc::Label;
951   using StateId = typename ToArc::StateId;
952   using AW = typename ToArc::Weight;
953   using GW = typename FromArc::Weight;
954   using SW = StringWeight<Label, GallicStringType(G)>;
955
956   explicit GallicToNewSymbolsMapper(MutableFst<ToArc> *fst)
957       : fst_(fst),
958         lmax_(0),
959         osymbols_(fst->OutputSymbols()),
960         isymbols_(nullptr),
961         error_(false) {
962     fst_->DeleteStates();
963     state_ = fst_->AddState();
964     fst_->SetStart(state_);
965     fst_->SetFinal(state_, AW::One());
966     if (osymbols_) {
967       string name = osymbols_->Name() + "_from_gallic";
968       fst_->SetInputSymbols(new SymbolTable(name));
969       isymbols_ = fst_->MutableInputSymbols();
970       const int64 zero = 0;
971       isymbols_->AddSymbol(osymbols_->Find(zero), 0);
972     } else {
973       fst_->SetInputSymbols(nullptr);
974     }
975   }
976
977   ToArc operator()(const FromArc &arc) {
978     // Super-non-final arc.
979     if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) {
980       return ToArc(arc.ilabel, 0, AW::Zero(), kNoStateId);
981     }
982     SW w1 = arc.weight.Value1();
983     AW w2 = arc.weight.Value2();
984     Label l;
985     if (w1.Size() == 0) {
986       l = 0;
987     } else {
988       auto insert_result = map_.insert(std::make_pair(w1, kNoLabel));
989       if (!insert_result.second) {
990         l = insert_result.first->second;
991       } else {
992         l = ++lmax_;
993         insert_result.first->second = l;
994         StringWeightIterator<SW> iter1(w1);
995         StateId n;
996         string s;
997         for (size_t i = 0, p = state_; i < w1.Size();
998              ++i, iter1.Next(), p = n) {
999           n = i == w1.Size() - 1 ? state_ : fst_->AddState();
1000           fst_->AddArc(p, ToArc(i ? 0 : l, iter1.Value(), AW::One(), n));
1001           if (isymbols_) {
1002             if (i) s = s + "_";
1003             s = s + osymbols_->Find(iter1.Value());
1004           }
1005         }
1006         if (isymbols_) isymbols_->AddSymbol(s, l);
1007       }
1008     }
1009     if (l == kStringInfinity || l == kStringBad || arc.ilabel != arc.olabel) {
1010       FSTERROR() << "GallicToNewSymbolMapper: Unrepresentable weight: " << l;
1011       error_ = true;
1012     }
1013     return ToArc(arc.ilabel, l, w2, arc.nextstate);
1014   }
1015
1016   constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }
1017
1018   constexpr MapSymbolsAction InputSymbolsAction() const {
1019     return MAP_COPY_SYMBOLS;
1020   }
1021
1022   constexpr MapSymbolsAction OutputSymbolsAction() const {
1023     return MAP_CLEAR_SYMBOLS;
1024   }
1025
1026   uint64 Properties(uint64 inprops) const {
1027     uint64 outprops = inprops & kOLabelInvariantProperties &
1028                       kWeightInvariantProperties & kAddSuperFinalProperties;
1029     if (error_) outprops |= kError;
1030     return outprops;
1031   }
1032
1033  private:
1034   class StringKey {
1035    public:
1036     size_t operator()(const SW &x) const { return x.Hash(); }
1037   };
1038
1039   using Map = std::unordered_map<SW, Label, StringKey>;
1040
1041   MutableFst<ToArc> *fst_;
1042   Map map_;
1043   Label lmax_;
1044   StateId state_;
1045   const SymbolTable *osymbols_;
1046   SymbolTable *isymbols_;
1047   mutable bool error_;
1048 };
1049
1050 // Mapper to add a constant to all weights.
1051 template <class A>
1052 class PlusMapper {
1053  public:
1054   using FromArc = A;
1055   using ToArc = A;
1056   using Weight = typename FromArc::Weight;
1057
1058   explicit PlusMapper(Weight weight) : weight_(std::move(weight)) {}
1059
1060   ToArc operator()(const FromArc &arc) const {
1061     if (arc.weight == Weight::Zero()) return arc;
1062     return ToArc(arc.ilabel, arc.olabel, Plus(arc.weight, weight_),
1063                  arc.nextstate);
1064   }
1065
1066   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
1067
1068   constexpr MapSymbolsAction InputSymbolsAction() const {
1069     return MAP_COPY_SYMBOLS;
1070   }
1071
1072   constexpr MapSymbolsAction OutputSymbolsAction() const {
1073     return MAP_COPY_SYMBOLS;
1074   }
1075
1076   uint64 Properties(uint64 props) const {
1077     return props & kWeightInvariantProperties;
1078   }
1079
1080  private:
1081   const Weight weight_;
1082 };
1083
1084 // Mapper to (right) multiply a constant to all weights.
1085 template <class A>
1086 class TimesMapper {
1087  public:
1088   using FromArc = A;
1089   using ToArc = A;
1090   using Weight = typename FromArc::Weight;
1091
1092   explicit TimesMapper(Weight weight) : weight_(std::move(weight)) {}
1093
1094   ToArc operator()(const FromArc &arc) const {
1095     if (arc.weight == Weight::Zero()) return arc;
1096     return ToArc(arc.ilabel, arc.olabel, Times(arc.weight, weight_),
1097                  arc.nextstate);
1098   }
1099
1100   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
1101
1102   constexpr MapSymbolsAction InputSymbolsAction() const {
1103     return MAP_COPY_SYMBOLS;
1104   }
1105
1106   constexpr MapSymbolsAction OutputSymbolsAction() const {
1107     return MAP_COPY_SYMBOLS;
1108   }
1109
1110   uint64 Properties(uint64 props) const {
1111     return props & kWeightInvariantProperties;
1112   }
1113
1114  private:
1115   const Weight weight_;
1116 };
1117
1118 // Mapper to take all arc-weights to a fixed power.
1119 template <class A>
1120 class PowerMapper {
1121  public:
1122   using FromArc = A;
1123   using ToArc = A;
1124   using Weight = typename FromArc::Weight;
1125
1126   explicit PowerMapper(size_t power) : power_(power) {}
1127
1128   ToArc operator()(const FromArc &arc) const {
1129     return ToArc(arc.ilabel, arc.olabel, Power(arc.weight, power_),
1130                  arc.nextstate);
1131   }
1132
1133   MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
1134
1135   MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
1136
1137   MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
1138
1139   uint64 Properties(uint64 props) const {
1140     return props & kWeightInvariantProperties;
1141   }
1142
1143  private:
1144   size_t power_;
1145 };
1146
1147 // Mapper to reciprocate all non-Zero() weights.
1148 template <class A>
1149 class InvertWeightMapper {
1150  public:
1151   using FromArc = A;
1152   using ToArc = A;
1153   using Weight = typename FromArc::Weight;
1154
1155   ToArc operator()(const FromArc &arc) const {
1156     if (arc.weight == Weight::Zero()) return arc;
1157     return ToArc(arc.ilabel, arc.olabel, Divide(Weight::One(), arc.weight),
1158                  arc.nextstate);
1159   }
1160
1161   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
1162
1163   constexpr MapSymbolsAction InputSymbolsAction() const {
1164     return MAP_COPY_SYMBOLS;
1165   }
1166
1167   constexpr MapSymbolsAction OutputSymbolsAction() const {
1168     return MAP_COPY_SYMBOLS;
1169   }
1170
1171   uint64 Properties(uint64 props) const {
1172     return props & kWeightInvariantProperties;
1173   }
1174 };
1175
1176 // Mapper to map all non-Zero() weights to One().
1177 template <class A, class B = A>
1178 class RmWeightMapper {
1179  public:
1180   using FromArc = A;
1181   using ToArc = B;
1182   using FromWeight = typename FromArc::Weight;
1183   using ToWeight = typename ToArc::Weight;
1184
1185   ToArc operator()(const FromArc &arc) const {
1186     return ToArc(arc.ilabel, arc.olabel,
1187                  arc.weight != FromWeight::Zero() ?
1188                  ToWeight::One() : ToWeight::Zero(),
1189                  arc.nextstate);
1190   }
1191
1192   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
1193
1194   constexpr MapSymbolsAction InputSymbolsAction() const {
1195     return MAP_COPY_SYMBOLS;
1196   }
1197
1198   constexpr MapSymbolsAction OutputSymbolsAction() const {
1199     return MAP_COPY_SYMBOLS;
1200   }
1201
1202   uint64 Properties(uint64 props) const {
1203     return (props & kWeightInvariantProperties) | kUnweighted;
1204   }
1205 };
1206
1207 // Mapper to quantize all weights.
1208 template <class A, class B = A>
1209 class QuantizeMapper {
1210  public:
1211   using FromArc = A;
1212   using ToArc = B;
1213   using FromWeight = typename FromArc::Weight;
1214   using ToWeight = typename ToArc::Weight;
1215
1216   QuantizeMapper() : delta_(kDelta) {}
1217
1218   explicit QuantizeMapper(float d) : delta_(d) {}
1219
1220   ToArc operator()(const FromArc &arc) const {
1221     return ToArc(arc.ilabel, arc.olabel, arc.weight.Quantize(delta_),
1222                  arc.nextstate);
1223   }
1224
1225   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
1226
1227   constexpr MapSymbolsAction InputSymbolsAction() const {
1228     return MAP_COPY_SYMBOLS;
1229   }
1230
1231   constexpr MapSymbolsAction OutputSymbolsAction() const {
1232     return MAP_COPY_SYMBOLS;
1233   }
1234
1235   uint64 Properties(uint64 props) const {
1236     return props & kWeightInvariantProperties;
1237   }
1238
1239  private:
1240   const float delta_;
1241 };
1242
1243 // Mapper from A to B under the assumption:
1244 //
1245 //    B::Weight = A::Weight::ReverseWeight
1246 //    B::Label == A::Label
1247 //    B::StateId == A::StateId
1248 //
1249 // The weight is reversed, while the label and nextstate are preserved.
1250 template <class A, class B>
1251 class ReverseWeightMapper {
1252  public:
1253   using FromArc = A;
1254   using ToArc = B;
1255
1256   ToArc operator()(const FromArc &arc) const {
1257     return ToArc(arc.ilabel, arc.olabel, arc.weight.Reverse(), arc.nextstate);
1258   }
1259
1260   constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
1261
1262   constexpr MapSymbolsAction InputSymbolsAction() const {
1263     return MAP_COPY_SYMBOLS;
1264   }
1265
1266   constexpr MapSymbolsAction OutputSymbolsAction() const {
1267     return MAP_COPY_SYMBOLS;
1268   }
1269
1270   uint64 Properties(uint64 props) const { return props; }
1271 };
1272
1273 }  // namespace fst
1274
1275 #endif  // FST_LIB_ARC_MAP_H_