Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / replace.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes for the recursive replacement of FSTs.
5
6 #ifndef FST_REPLACE_H_
7 #define FST_REPLACE_H_
8
9 #include <set>
10 #include <string>
11 #include <unordered_map>
12 #include <utility>
13 #include <vector>
14
15 #include <fst/log.h>
16
17 #include <fst/cache.h>
18 #include <fst/expanded-fst.h>
19 #include <fst/fst-decl.h>  // For optional argument declarations.
20 #include <fst/fst.h>
21 #include <fst/matcher.h>
22 #include <fst/replace-util.h>
23 #include <fst/state-table.h>
24 #include <fst/test-properties.h>
25
26 namespace fst {
27
28 // Replace state tables have the form:
29 //
30 // template <class Arc, class P>
31 // class ReplaceStateTable {
32 //  public:
33 //   using Label = typename Arc::Label Label;
34 //   using StateId = typename Arc::StateId;
35 //
36 //   using PrefixId = P;
37 //   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
38 //   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
39 //
40 //   // Required constructor.
41 //   ReplaceStateTable(
42 //       const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
43 //       Label root);
44 //
45 //   // Required copy constructor that does not copy state.
46 //   ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
47 //
48 //   // Looks up state ID by tuple, adding it if it doesn't exist.
49 //   StateId FindState(const StateTuple &tuple);
50 //
51 //   // Looks up state tuple by ID.
52 //   const StateTuple &Tuple(StateId id) const;
53 //
54 //   // Lookus up prefix ID by stack prefix, adding it if it doesn't exist.
55 //   PrefixId FindPrefixId(const StackPrefix &stack_prefix);
56 //
57 //  // Looks up stack prefix by ID.
58 //  const StackPrefix &GetStackPrefix(PrefixId id) const;
59 // };
60
61 // Tuple that uniquely defines a state in replace.
62 template <class S, class P>
63 struct ReplaceStateTuple {
64   using StateId = S;
65   using PrefixId = P;
66
67   ReplaceStateTuple(PrefixId prefix_id = -1, StateId fst_id = kNoStateId,
68                     StateId fst_state = kNoStateId)
69       : prefix_id(prefix_id), fst_id(fst_id), fst_state(fst_state) {}
70
71   PrefixId prefix_id;  // Index in prefix table.
72   StateId fst_id;      // Current FST being walked.
73   StateId fst_state;   // Current state in FST being walked (not to be
74                        // confused with the thse StateId of the combined FST).
75 };
76
77 // Equality of replace state tuples.
78 template <class StateId, class PrefixId>
79 inline bool operator==(const ReplaceStateTuple<StateId, PrefixId> &x,
80                        const ReplaceStateTuple<StateId, PrefixId> &y) {
81   return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
82          x.fst_state == y.fst_state;
83 }
84
85 // Functor returning true for tuples corresponding to states in the root FST.
86 template <class StateId, class PrefixId>
87 class ReplaceRootSelector {
88  public:
89   bool operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
90     return tuple.prefix_id == 0;
91   }
92 };
93
94 // Functor for fingerprinting replace state tuples.
95 template <class StateId, class PrefixId>
96 class ReplaceFingerprint {
97  public:
98   explicit ReplaceFingerprint(const std::vector<uint64> *size_array)
99       : size_array_(size_array) {}
100
101   uint64 operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
102     return tuple.prefix_id * size_array_->back() +
103            size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
104   }
105
106  private:
107   const std::vector<uint64> *size_array_;
108 };
109
110 // Useful when the fst_state uniquely define the tuple.
111 template <class StateId, class PrefixId>
112 class ReplaceFstStateFingerprint {
113  public:
114   uint64 operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
115     return tuple.fst_state;
116   }
117 };
118
119 // A generic hash function for replace state tuples.
120 template <typename S, typename P>
121 class ReplaceHash {
122  public:
123   size_t operator()(const ReplaceStateTuple<S, P>& t) const {
124     static constexpr auto prime0 = 7853;
125     static constexpr auto prime1 = 7867;
126     return t.prefix_id + t.fst_id * prime0 + t.fst_state * prime1;
127   }
128 };
129
130 // Container for stack prefix.
131 template <class Label, class StateId>
132 class ReplaceStackPrefix {
133  public:
134   struct PrefixTuple {
135     PrefixTuple(Label fst_id = kNoLabel, StateId nextstate = kNoStateId)
136         : fst_id(fst_id), nextstate(nextstate) {}
137
138     Label fst_id;
139     StateId nextstate;
140   };
141
142   ReplaceStackPrefix() {}
143
144   ReplaceStackPrefix(const ReplaceStackPrefix &other)
145       : prefix_(other.prefix_) {}
146
147   void Push(StateId fst_id, StateId nextstate) {
148     prefix_.push_back(PrefixTuple(fst_id, nextstate));
149   }
150
151   void Pop() { prefix_.pop_back(); }
152
153   const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
154
155   size_t Depth() const { return prefix_.size(); }
156
157  public:
158   std::vector<PrefixTuple> prefix_;
159 };
160
161 // Equality stack prefix classes.
162 template <class Label, class StateId>
163 inline bool operator==(const ReplaceStackPrefix<Label, StateId> &x,
164                        const ReplaceStackPrefix<Label, StateId> &y) {
165   if (x.prefix_.size() != y.prefix_.size()) return false;
166   for (size_t i = 0; i < x.prefix_.size(); ++i) {
167     if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
168         x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
169       return false;
170     }
171   }
172   return true;
173 }
174
175 // Hash function for stack prefix to prefix id.
176 template <class Label, class StateId>
177 class ReplaceStackPrefixHash {
178  public:
179   size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
180     size_t sum = 0;
181     for (const auto &pair : prefix.prefix_) {
182       static constexpr auto prime = 7863;
183       sum += pair.fst_id + pair.nextstate * prime;
184     }
185     return sum;
186   }
187 };
188
189 // Replace state tables.
190
191 // A two-level state table for replace. Warning: calls CountStates to compute
192 // the number of states of each component FST.
193 template <class Arc, class P = ssize_t>
194 class VectorHashReplaceStateTable {
195  public:
196   using Label = typename Arc::Label;
197   using StateId = typename Arc::StateId;
198
199   using PrefixId = P;
200
201   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
202   using StateTable =
203       VectorHashStateTable<ReplaceStateTuple<StateId, PrefixId>,
204                            ReplaceRootSelector<StateId, PrefixId>,
205                            ReplaceFstStateFingerprint<StateId, PrefixId>,
206                            ReplaceFingerprint<StateId, PrefixId>>;
207   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
208   using StackPrefixTable =
209       CompactHashBiTable<PrefixId, StackPrefix,
210                          ReplaceStackPrefixHash<Label, StateId>>;
211
212   VectorHashReplaceStateTable(
213       const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
214       Label root)
215       : root_size_(0) {
216     size_array_.push_back(0);
217     for (const auto &fst_pair : fst_list) {
218       if (fst_pair.first == root) {
219         root_size_ = CountStates(*(fst_pair.second));
220         size_array_.push_back(size_array_.back());
221       } else {
222         size_array_.push_back(size_array_.back() +
223                               CountStates(*(fst_pair.second)));
224       }
225     }
226     state_table_.reset(
227         new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
228                        new ReplaceFstStateFingerprint<StateId, PrefixId>,
229                        new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
230                        root_size_, root_size_ + size_array_.back()));
231   }
232
233   VectorHashReplaceStateTable(
234       const VectorHashReplaceStateTable<Arc, PrefixId> &table)
235       : root_size_(table.root_size_),
236         size_array_(table.size_array_),
237         prefix_table_(table.prefix_table_) {
238     state_table_.reset(
239         new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
240                        new ReplaceFstStateFingerprint<StateId, PrefixId>,
241                        new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
242                        root_size_, root_size_ + size_array_.back()));
243   }
244
245   StateId FindState(const StateTuple &tuple) {
246     return state_table_->FindState(tuple);
247   }
248
249   const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
250
251   PrefixId FindPrefixId(const StackPrefix &prefix) {
252     return prefix_table_.FindId(prefix);
253   }
254
255   const StackPrefix& GetStackPrefix(PrefixId id) const {
256     return prefix_table_.FindEntry(id);
257   }
258
259  private:
260   StateId root_size_;
261   std::vector<uint64> size_array_;
262   std::unique_ptr<StateTable> state_table_;
263   StackPrefixTable prefix_table_;
264 };
265
266 // Default replace state table.
267 template <class Arc, class P /* = size_t */>
268 class DefaultReplaceStateTable
269     : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
270                                    ReplaceHash<typename Arc::StateId, P>> {
271  public:
272   using Label = typename Arc::Label;
273   using StateId = typename Arc::StateId;
274
275   using PrefixId = P;
276   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
277   using StateTable =
278       CompactHashStateTable<StateTuple, ReplaceHash<StateId, PrefixId>>;
279   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
280   using StackPrefixTable =
281       CompactHashBiTable<PrefixId, StackPrefix,
282                          ReplaceStackPrefixHash<Label, StateId>>;
283
284   using StateTable::FindState;
285   using StateTable::Tuple;
286
287   DefaultReplaceStateTable(
288       const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
289
290   DefaultReplaceStateTable(const DefaultReplaceStateTable<Arc, PrefixId> &table)
291       : StateTable(), prefix_table_(table.prefix_table_) {}
292
293   PrefixId FindPrefixId(const StackPrefix &prefix) {
294     return prefix_table_.FindId(prefix);
295   }
296
297   const StackPrefix &GetStackPrefix(PrefixId id) const {
298     return prefix_table_.FindEntry(id);
299   }
300
301  private:
302   StackPrefixTable prefix_table_;
303 };
304
305 // By default ReplaceFst will copy the input label of the replace arc.
306 // The call_label_type and return_label_type options specify how to manage
307 // the labels of the call arc and the return arc of the replace FST
308 template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
309           class CacheStore = DefaultCacheStore<Arc>>
310 struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
311   using Label = typename Arc::Label;
312
313   // Index of root rule for expansion.
314   Label root;
315   // How to label call arc.
316   ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT;
317   // How to label return arc.
318   ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER;
319   // Specifies output label to put on call arc; if kNoLabel, use existing label
320   // on call arc. Otherwise, use this field as the output label.
321   Label call_output_label = kNoLabel;
322   // Specifies label to put on return arc.
323   Label return_label = 0;
324   // Take ownership of input FSTs?
325   bool take_ownership = false;
326   // Pointer to optional pre-constructed state table.
327   StateTable *state_table = nullptr;
328
329   explicit ReplaceFstOptions(const CacheImplOptions<CacheStore> &opts,
330                              Label root = kNoLabel)
331       : CacheImplOptions<CacheStore>(opts), root(root) {}
332
333   explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
334       : CacheImplOptions<CacheStore>(opts), root(root) {}
335
336   // FIXME(kbg): There are too many constructors here. Come up with a consistent
337   // position for call_output_label (probably the very end) so that it is
338   // possible to express all the remaining constructors with a single
339   // default-argument constructor. Also move clients off of the "backwards
340   // compatibility" constructor, for good.
341
342   explicit ReplaceFstOptions(Label root) : root(root) {}
343
344   explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
345                              ReplaceLabelType return_label_type,
346                              Label return_label)
347       : root(root),
348         call_label_type(call_label_type),
349         return_label_type(return_label_type),
350         return_label(return_label) {}
351
352   explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
353                              ReplaceLabelType return_label_type,
354                              Label call_output_label, Label return_label)
355       : root(root),
356         call_label_type(call_label_type),
357         return_label_type(return_label_type),
358         call_output_label(call_output_label),
359         return_label(return_label) {}
360
361   explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
362       : ReplaceFstOptions(opts.root, opts.call_label_type,
363                           opts.return_label_type, opts.return_label) {}
364
365   ReplaceFstOptions() : root(kNoLabel) {}
366
367   // For backwards compatibility.
368   ReplaceFstOptions(int64 root, bool epsilon_replace_arc)
369       : root(root),
370         call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
371                                             : REPLACE_LABEL_INPUT),
372         call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
373 };
374
375
376 // Forward declaration.
377 template <class Arc, class StateTable, class CacheStore>
378 class ReplaceFstMatcher;
379
380 template <class Arc>
381 using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
382
383 // Returns true if label type on arc results in epsilon input label.
384 inline bool EpsilonOnInput(ReplaceLabelType label_type) {
385   return label_type == REPLACE_LABEL_NEITHER ||
386          label_type == REPLACE_LABEL_OUTPUT;
387 }
388
389 // Returns true if label type on arc results in epsilon input label.
390 inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
391   return label_type == REPLACE_LABEL_NEITHER ||
392          label_type == REPLACE_LABEL_INPUT;
393 }
394
395 // Returns true if for either the call or return arc ilabel != olabel.
396 template <class Label>
397 bool ReplaceTransducer(ReplaceLabelType call_label_type,
398                        ReplaceLabelType return_label_type,
399                        Label call_output_label) {
400   return call_label_type == REPLACE_LABEL_INPUT ||
401          call_label_type == REPLACE_LABEL_OUTPUT ||
402          (call_label_type == REPLACE_LABEL_BOTH &&
403           call_output_label != kNoLabel) ||
404          return_label_type == REPLACE_LABEL_INPUT ||
405          return_label_type == REPLACE_LABEL_OUTPUT;
406 }
407
408 template <class Arc>
409 uint64 ReplaceFstProperties(typename Arc::Label root_label,
410                             const FstList<Arc> &fst_list,
411                             ReplaceLabelType call_label_type,
412                             ReplaceLabelType return_label_type,
413                             typename Arc::Label call_output_label,
414                             bool *sorted_and_non_empty) {
415   using Label = typename Arc::Label;
416   std::vector<uint64> inprops;
417   bool all_ilabel_sorted = true;
418   bool all_olabel_sorted = true;
419   bool all_non_empty = true;
420   // All nonterminals are negative?
421   bool all_negative = true;
422   // All nonterminals are positive and form a dense range containing 1?
423   bool dense_range = true;
424   Label root_fst_idx = 0;
425   for (Label i = 0; i < fst_list.size(); ++i) {
426     const auto label = fst_list[i].first;
427     if (label >= 0) all_negative = false;
428     if (label > fst_list.size() || label <= 0) dense_range = false;
429     if (label == root_label) root_fst_idx = i;
430     const auto *fst = fst_list[i].second;
431     if (fst->Start() == kNoStateId) all_non_empty = false;
432     if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
433     if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
434     inprops.push_back(fst->Properties(kCopyProperties, false));
435   }
436   const auto props = ReplaceProperties(
437       inprops, root_fst_idx, EpsilonOnInput(call_label_type),
438       EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
439       EpsilonOnOutput(return_label_type),
440       ReplaceTransducer(call_label_type, return_label_type, call_output_label),
441       all_non_empty, all_ilabel_sorted, all_olabel_sorted,
442       all_negative || dense_range);
443   const bool sorted = props & (kILabelSorted | kOLabelSorted);
444   *sorted_and_non_empty = all_non_empty && sorted;
445   return props;
446 }
447
448 namespace internal {
449
450 // The replace implementation class supports a dynamic expansion of a recursive
451 // transition network represented as label/FST pairs with dynamic replacable
452 // arcs.
453 template <class Arc, class StateTable, class CacheStore>
454 class ReplaceFstImpl
455     : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
456  public:
457   using Label = typename Arc::Label;
458   using StateId = typename Arc::StateId;
459   using Weight = typename Arc::Weight;
460
461   using State = typename CacheStore::State;
462   using CacheImpl = CacheBaseImpl<State, CacheStore>;
463   using PrefixId = typename StateTable::PrefixId;
464   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
465   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
466   using NonTerminalHash = std::unordered_map<Label, Label>;
467
468   using FstImpl<Arc>::SetType;
469   using FstImpl<Arc>::SetProperties;
470   using FstImpl<Arc>::WriteHeader;
471   using FstImpl<Arc>::SetInputSymbols;
472   using FstImpl<Arc>::SetOutputSymbols;
473   using FstImpl<Arc>::InputSymbols;
474   using FstImpl<Arc>::OutputSymbols;
475
476   using CacheImpl::PushArc;
477   using CacheImpl::HasArcs;
478   using CacheImpl::HasFinal;
479   using CacheImpl::HasStart;
480   using CacheImpl::SetArcs;
481   using CacheImpl::SetFinal;
482   using CacheImpl::SetStart;
483
484   friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
485
486   ReplaceFstImpl(const FstList<Arc> &fst_list,
487                  const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
488       : CacheImpl(opts),
489         call_label_type_(opts.call_label_type),
490         return_label_type_(opts.return_label_type),
491         call_output_label_(opts.call_output_label),
492         return_label_(opts.return_label),
493         state_table_(opts.state_table ? opts.state_table
494                                       : new StateTable(fst_list, opts.root)) {
495     SetType("replace");
496     // If the label is epsilon, then all replace label options are equivalent,
497     // so we set the label types to NEITHER for simplicity.
498     if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
499     if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
500     if (!fst_list.empty()) {
501       SetInputSymbols(fst_list[0].second->InputSymbols());
502       SetOutputSymbols(fst_list[0].second->OutputSymbols());
503     }
504     fst_array_.push_back(nullptr);
505     for (Label i = 0; i < fst_list.size(); ++i) {
506       const auto label = fst_list[i].first;
507       const auto *fst = fst_list[i].second;
508       nonterminal_hash_[label] = fst_array_.size();
509       nonterminal_set_.insert(label);
510       fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
511       if (i) {
512         if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
513           FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
514                      << " do not match input symbols of base FST (0th FST)";
515           SetProperties(kError, kError);
516         }
517         if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
518           FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
519                      << " do not match output symbols of base FST (0th FST)";
520           SetProperties(kError, kError);
521         }
522       }
523     }
524     const auto nonterminal = nonterminal_hash_[opts.root];
525     if ((nonterminal == 0) && (fst_array_.size() > 1)) {
526       FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
527                  << opts.root << " in the input tuple vector";
528       SetProperties(kError, kError);
529     }
530     root_ = (nonterminal > 0) ? nonterminal : 1;
531     bool all_non_empty_and_sorted = false;
532     SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
533                                        return_label_type_, call_output_label_,
534                                        &all_non_empty_and_sorted));
535     // Enables optional caching as long as sorted and all non-empty.
536     always_cache_ = !all_non_empty_and_sorted;
537     VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
538             << (always_cache_ ? "true" : "false");
539   }
540
541   ReplaceFstImpl(const ReplaceFstImpl &impl)
542       : CacheImpl(impl),
543         call_label_type_(impl.call_label_type_),
544         return_label_type_(impl.return_label_type_),
545         call_output_label_(impl.call_output_label_),
546         return_label_(impl.return_label_),
547         always_cache_(impl.always_cache_),
548         state_table_(new StateTable(*(impl.state_table_))),
549         nonterminal_set_(impl.nonterminal_set_),
550         nonterminal_hash_(impl.nonterminal_hash_),
551         root_(impl.root_) {
552     SetType("replace");
553     SetProperties(impl.Properties(), kCopyProperties);
554     SetInputSymbols(impl.InputSymbols());
555     SetOutputSymbols(impl.OutputSymbols());
556     fst_array_.reserve(impl.fst_array_.size());
557     fst_array_.emplace_back(nullptr);
558     for (Label i = 1; i < impl.fst_array_.size(); ++i) {
559       fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
560     }
561   }
562
563   // Computes the dependency graph of the replace class and returns
564   // true if the dependencies are cyclic. Cyclic dependencies will result
565   // in an un-expandable FST.
566   bool CyclicDependencies() const {
567     const ReplaceUtilOptions opts(root_);
568     ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
569     return replace_util.CyclicDependencies();
570   }
571
572   StateId Start() {
573     if (!HasStart()) {
574       if (fst_array_.size() == 1) {
575         SetStart(kNoStateId);
576         return kNoStateId;
577       } else {
578         const auto fst_start = fst_array_[root_]->Start();
579         if (fst_start == kNoStateId) return kNoStateId;
580         const auto prefix = GetPrefixId(StackPrefix());
581         const auto start =
582             state_table_->FindState(StateTuple(prefix, root_, fst_start));
583         SetStart(start);
584         return start;
585       }
586     } else {
587       return CacheImpl::Start();
588     }
589   }
590
591   Weight Final(StateId s) {
592     if (HasFinal(s)) return CacheImpl::Final(s);
593     const auto &tuple = state_table_->Tuple(s);
594     auto weight = Weight::Zero();
595     if (tuple.prefix_id == 0) {
596       const auto fst_state = tuple.fst_state;
597       weight = fst_array_[tuple.fst_id]->Final(fst_state);
598     }
599     if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
600     return weight;
601   }
602
603   size_t NumArcs(StateId s) {
604     if (HasArcs(s)) {
605       return CacheImpl::NumArcs(s);
606     } else if (always_cache_) {  // If always caching, expands and caches state.
607       Expand(s);
608       return CacheImpl::NumArcs(s);
609     } else {  // Otherwise computes the number of arcs without expanding.
610       const auto tuple = state_table_->Tuple(s);
611       if (tuple.fst_state == kNoStateId) return 0;
612       auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
613       if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
614       return num_arcs;
615     }
616   }
617
618   // Returns whether a given label is a non-terminal.
619   bool IsNonTerminal(Label label) const {
620     if (label < *nonterminal_set_.begin() ||
621         label > *nonterminal_set_.rbegin()) {
622       return false;
623     } else {
624       return nonterminal_hash_.count(label);
625     }
626     // TODO(allauzen): be smarter and take advantage of all_dense or
627     // all_negative. Also use this in ComputeArc. This would require changes to
628     // Replace so that recursing into an empty FST lead to a non co-accessible
629     // state instead of deleting the arc as done currently. The current use
630     // correct, since labels are sorted if all_non_empty is true.
631   }
632
633   size_t NumInputEpsilons(StateId s) {
634     if (HasArcs(s)) {
635       return CacheImpl::NumInputEpsilons(s);
636     } else if (always_cache_ || !Properties(kILabelSorted)) {
637       // If always caching or if the number of input epsilons is too expensive
638       // to compute without caching (i.e., not ilabel-sorted), then expands and
639       // caches state.
640       Expand(s);
641       return CacheImpl::NumInputEpsilons(s);
642     } else {
643       // Otherwise, computes the number of input epsilons without caching.
644       const auto tuple = state_table_->Tuple(s);
645       if (tuple.fst_state == kNoStateId) return 0;
646       size_t num = 0;
647       if (!EpsilonOnInput(call_label_type_)) {
648         // If EpsilonOnInput(c) is false, all input epsilon arcs
649         // are also input epsilons arcs in the underlying machine.
650         num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
651       } else {
652         // Otherwise, one need to consider that all non-terminal arcs
653         // in the underlying machine also become input epsilon arc.
654         ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
655         for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
656                                  IsNonTerminal(aiter.Value().olabel));
657              aiter.Next()) {
658           ++num;
659         }
660       }
661       if (EpsilonOnInput(return_label_type_) &&
662           ComputeFinalArc(tuple, nullptr)) {
663         ++num;
664       }
665       return num;
666     }
667   }
668
669   size_t NumOutputEpsilons(StateId s) {
670     if (HasArcs(s)) {
671       return CacheImpl::NumOutputEpsilons(s);
672     } else if (always_cache_ || !Properties(kOLabelSorted)) {
673       // If always caching or if the number of output epsilons is too expensive
674       // to compute without caching (i.e., not olabel-sorted), then expands and
675       // caches state.
676       Expand(s);
677       return CacheImpl::NumOutputEpsilons(s);
678     } else {
679       // Otherwise, computes the number of output epsilons without caching.
680       const auto tuple = state_table_->Tuple(s);
681       if (tuple.fst_state == kNoStateId) return 0;
682       size_t num = 0;
683       if (!EpsilonOnOutput(call_label_type_)) {
684         // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
685         // output epsilons arcs in the underlying machine.
686         num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
687       } else {
688         // Otherwise, one need to consider that all non-terminal arcs in the
689         // underlying machine also become output epsilon arc.
690         ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
691         for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
692                                  IsNonTerminal(aiter.Value().olabel));
693              aiter.Next()) {
694           ++num;
695         }
696       }
697       if (EpsilonOnOutput(return_label_type_) &&
698           ComputeFinalArc(tuple, nullptr)) {
699         ++num;
700       }
701       return num;
702     }
703   }
704
705   uint64 Properties() const override { return Properties(kFstProperties); }
706
707   // Sets error if found, and returns other FST impl properties.
708   uint64 Properties(uint64 mask) const override {
709     if (mask & kError) {
710       for (Label i = 1; i < fst_array_.size(); ++i) {
711         if (fst_array_[i]->Properties(kError, false)) {
712           SetProperties(kError, kError);
713         }
714       }
715     }
716     return FstImpl<Arc>::Properties(mask);
717   }
718
719   // Returns the base arc iterator, and if arcs have not been computed yet,
720   // extends and recurses for new arcs.
721   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
722     if (!HasArcs(s)) Expand(s);
723     CacheImpl::InitArcIterator(s, data);
724     // TODO(allauzen): Set behaviour of generic iterator.
725     // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
726     // behaviour.
727   }
728
729   // Extends current state (walk arcs one level deep).
730   void Expand(StateId s) {
731     const auto tuple = state_table_->Tuple(s);
732     if (tuple.fst_state == kNoStateId) {  // Local FST is empty.
733       SetArcs(s);
734       return;
735     }
736     ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
737     Arc arc;
738     // Creates a final arc when needed.
739     if (ComputeFinalArc(tuple, &arc)) PushArc(s, arc);
740     // Expands all arcs leaving the state.
741     for (; !aiter.Done(); aiter.Next()) {
742       if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, arc);
743     }
744     SetArcs(s);
745   }
746
747   void Expand(StateId s, const StateTuple &tuple,
748               const ArcIteratorData<Arc> &data) {
749     if (tuple.fst_state == kNoStateId) {  // Local FST is empty.
750       SetArcs(s);
751       return;
752     }
753     ArcIterator<Fst<Arc>> aiter(data);
754     Arc arc;
755     // Creates a final arc when needed.
756     if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
757     // Expands all arcs leaving the state.
758     for (; !aiter.Done(); aiter.Next()) {
759       if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
760     }
761     SetArcs(s);
762   }
763
764   // If acpp is null, only returns true if a final arcp is required, but does
765   // not actually compute it.
766   bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
767                        uint32 flags = kArcValueFlags) {
768     const auto fst_state = tuple.fst_state;
769     if (fst_state == kNoStateId) return false;
770     // If state is final, pops the stack.
771     if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
772         tuple.prefix_id) {
773       if (arcp) {
774         arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
775         arcp->olabel =
776             (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
777         if (flags & kArcNextStateValue) {
778           const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
779           const auto prefix_id = PopPrefix(stack);
780           const auto &top = stack.Top();
781           arcp->nextstate = state_table_->FindState(
782               StateTuple(prefix_id, top.fst_id, top.nextstate));
783         }
784         if (flags & kArcWeightValue) {
785           arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
786         }
787       }
788       return true;
789     } else {
790       return false;
791     }
792   }
793
794   // Computes an arc in the FST corresponding to one in the underlying machine.
795   // Returns false if the underlying arc corresponds to no arc in the resulting
796   // FST.
797   bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
798                   uint32 flags = kArcValueFlags) {
799     if (!EpsilonOnInput(call_label_type_) &&
800         (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
801       *arcp = arc;
802       return true;
803     }
804     if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
805         arc.olabel > *nonterminal_set_.rbegin()) {  // Expands local FST.
806       const auto nextstate =
807           flags & kArcNextStateValue
808               ? state_table_->FindState(
809                     StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
810               : kNoStateId;
811       *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
812     } else {
813       // Checks for non-terminal.
814       const auto it = nonterminal_hash_.find(arc.olabel);
815       if (it != nonterminal_hash_.end()) {  // Recurses into non-terminal.
816         const auto nonterminal = it->second;
817         const auto nt_prefix =
818             PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
819                        tuple.fst_id, arc.nextstate);
820         // If the start state is valid, replace; othewise, the arc is implicitly
821         // deleted.
822         const auto nt_start = fst_array_[nonterminal]->Start();
823         if (nt_start != kNoStateId) {
824           const auto nt_nextstate = flags & kArcNextStateValue
825                                         ? state_table_->FindState(StateTuple(
826                                               nt_prefix, nonterminal, nt_start))
827                                         : kNoStateId;
828           const auto ilabel =
829               (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
830           const auto olabel =
831               (EpsilonOnOutput(call_label_type_))
832                   ? 0
833                   : ((call_output_label_ == kNoLabel) ? arc.olabel
834                                                       : call_output_label_);
835           *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
836         } else {
837           return false;
838         }
839       } else {
840         const auto nextstate =
841             flags & kArcNextStateValue
842                 ? state_table_->FindState(
843                       StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
844                 : kNoStateId;
845         *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
846       }
847     }
848     return true;
849   }
850
851   // Returns the arc iterator flags supported by this FST.
852   uint32 ArcIteratorFlags() const {
853     uint32 flags = kArcValueFlags;
854     if (!always_cache_) flags |= kArcNoCache;
855     return flags;
856   }
857
858   StateTable *GetStateTable() const { return state_table_.get(); }
859
860   const Fst<Arc> *GetFst(Label fst_id) const {
861     return fst_array_[fst_id].get();
862   }
863
864   Label GetFstId(Label nonterminal) const {
865     const auto it = nonterminal_hash_.find(nonterminal);
866     if (it == nonterminal_hash_.end()) {
867       FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
868                  << nonterminal;
869     }
870     return it->second;
871   }
872
873   // Returns true if label type on call arc results in epsilon input label.
874   bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
875
876  private:
877   // The unique index into stack prefix table.
878   PrefixId GetPrefixId(const StackPrefix &prefix) {
879     return state_table_->FindPrefixId(prefix);
880   }
881
882   // The prefix ID after a stack pop.
883   PrefixId PopPrefix(StackPrefix prefix) {
884     prefix.Pop();
885     return GetPrefixId(prefix);
886   }
887
888   // The prefix ID after a stack push.
889   PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
890     prefix.Push(fst_id, nextstate);
891     return GetPrefixId(prefix);
892   }
893
894   // Runtime options
895   ReplaceLabelType call_label_type_;    // How to label call arc.
896   ReplaceLabelType return_label_type_;  // How to label return arc.
897   int64 call_output_label_;  // Specifies output label to put on call arc
898   int64 return_label_;       // Specifies label to put on return arc.
899   bool always_cache_;        // Disable optional caching of arc iterator?
900
901   // State table.
902   std::unique_ptr<StateTable> state_table_;
903
904   // Replace components.
905   std::set<Label> nonterminal_set_;
906   NonTerminalHash nonterminal_hash_;
907   std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
908   Label root_;
909 };
910
911 }  // namespace internal
912
913 //
914 // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
915 // This replacement is recursive. ReplaceFst can be used to support a variety of
916 // delayed constructions such as recursive
917 // transition networks, union, or closure. It is constructed with an array of
918 // FST(s). One FST represents the root (or topology) machine. The root FST
919 // refers to other FSTs by recursively replacing arcs labeled as non-terminals
920 // with the matching non-terminal FST. Currently the ReplaceFst uses the output
921 // symbols of the arcs to determine whether the arc is a non-terminal arc or
922 // not. A non-terminal can be any label that is not a non-zero terminal label in
923 // the output alphabet.
924 //
925 // Note that the constructor uses a vector of pairs. These correspond to the
926 // tuple of non-terminal Label and corresponding FST. For example to implement
927 // the closure operation we need 2 FSTs. The first root FST is a single
928 // self-loop arc on the start state.
929 //
930 // The ReplaceFst class supports an optionally caching arc iterator.
931 //
932 // The ReplaceFst needs to be built such that it is known to be ilabel- or
933 // olabel-sorted (see usage below).
934 //
935 // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
936 // when available (the FST is ilabel-sorted and matching on the input, or the
937 // FST is olabel -orted and matching on the output).  In order to obtain the
938 // most efficient behaviour, it is recommended to set call_label_type to
939 // REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and return_label_type to
940 // REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER. This means that the call arc
941 // does not have epsilon on the input side and the return arc has epsilon on the
942 // input side) and matching on the input side.
943 //
944 // This class attaches interface to implementation and handles reference
945 // counting, delegating most methods to ImplToFst.
946 template <class A, class T /* = DefaultReplaceStateTable<A> */,
947           class CacheStore /* = DefaultCacheStore<A> */>
948 class ReplaceFst
949     : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
950  public:
951   using Arc = A;
952   using Label = typename Arc::Label;
953   using StateId = typename Arc::StateId;
954   using Weight = typename Arc::Weight;
955
956   using StateTable = T;
957   using Store = CacheStore;
958   using State = typename CacheStore::State;
959   using Impl = internal::ReplaceFstImpl<Arc, StateTable, CacheStore>;
960   using CacheImpl = internal::CacheBaseImpl<State, CacheStore>;
961
962   using ImplToFst<Impl>::Properties;
963
964   friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
965   friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
966   friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
967
968   ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
969              Label root)
970       : ImplToFst<Impl>(std::make_shared<Impl>(
971             fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
972
973   ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
974              const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
975       : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
976
977   // See Fst<>::Copy() for doc.
978   ReplaceFst(const ReplaceFst<Arc, StateTable, CacheStore> &fst,
979              bool safe = false)
980       : ImplToFst<Impl>(fst, safe) {}
981
982   // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
983   ReplaceFst<Arc, StateTable, CacheStore> *Copy(
984       bool safe = false) const override {
985     return new ReplaceFst<Arc, StateTable, CacheStore>(*this, safe);
986   }
987
988   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
989
990   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
991     GetMutableImpl()->InitArcIterator(s, data);
992   }
993
994   MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
995     if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
996         ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
997          (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
998       return new ReplaceFstMatcher<Arc, StateTable, CacheStore>
999           (this, match_type);
1000     } else {
1001       VLOG(2) << "Not using replace matcher";
1002       return nullptr;
1003     }
1004   }
1005
1006   bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
1007
1008   const StateTable &GetStateTable() const {
1009     return *GetImpl()->GetStateTable();
1010   }
1011
1012   const Fst<Arc> &GetFst(Label nonterminal) const {
1013     return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1014   }
1015
1016  private:
1017   using ImplToFst<Impl>::GetImpl;
1018   using ImplToFst<Impl>::GetMutableImpl;
1019
1020   ReplaceFst &operator=(const ReplaceFst &) = delete;
1021 };
1022
1023 // Specialization for ReplaceFst.
1024 template <class Arc, class StateTable, class CacheStore>
1025 class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
1026     : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1027  public:
1028   explicit StateIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst)
1029       : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
1030             fst, fst.GetMutableImpl()) {}
1031 };
1032
1033 // Specialization for ReplaceFst, implementing optional caching. It is be used
1034 // as follows:
1035 //
1036 //   ReplaceFst<A> replace;
1037 //   ArcIterator<ReplaceFst<A>> aiter(replace, s);
1038 //   // Note: ArcIterator< Fst<A>> is always a caching arc iterator.
1039 //   aiter.SetFlags(kArcNoCache, kArcNoCache);
1040 //   // Uses the arc iterator, no arc will be cached, no state will be expanded.
1041 //   // Arc flags can be used to decide which component of the arc need to be
1042 //   computed.
1043 //   aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1044 //   // Wants the ilabel for this arc.
1045 //   aiter.Value();  // Does not compute the destination state.
1046 //   aiter.Next();
1047 //   aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1048 //   // Wants the ilabel and next state for this arc.
1049 //   aiter.Value();  // Does compute the destination state and inserts it
1050 //                   // in the replace state table.
1051 //   // No additional arcs have been cached at this point.
1052 template <class Arc, class StateTable, class CacheStore>
1053 class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1054  public:
1055   using StateId = typename Arc::StateId;
1056
1057   using StateTuple = typename StateTable::StateTuple;
1058
1059   ArcIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst, StateId s)
1060       : fst_(fst),
1061         s_(s),
1062         pos_(0),
1063         offset_(0),
1064         flags_(kArcValueFlags),
1065         arcs_(nullptr),
1066         data_flags_(0),
1067         final_flags_(0) {
1068     cache_data_.ref_count = nullptr;
1069     local_data_.ref_count = nullptr;
1070     // If FST does not support optional caching, forces caching.
1071     if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1072         !(fst_.GetImpl()->HasArcs(s_))) {
1073       fst_.GetMutableImpl()->Expand(s_);
1074     }
1075     // If state is already cached, use cached arcs array.
1076     if (fst_.GetImpl()->HasArcs(s_)) {
1077       (fst_.GetImpl())
1078           ->internal::template CacheBaseImpl<
1079               typename CacheStore::State,
1080               CacheStore>::InitArcIterator(s_, &cache_data_);
1081       num_arcs_ = cache_data_.narcs;
1082       arcs_ = cache_data_.arcs;      // arcs_ is a pointer to the cached arcs.
1083       data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1084     } else {  // Otherwise delay decision until Value() is called.
1085       tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
1086       if (tuple_.fst_state == kNoStateId) {
1087         num_arcs_ = 0;
1088       } else {
1089         // The decision to cache or not to cache has been defered until Value()
1090         // or
1091         // SetFlags() is called. However, the arc iterator is set up now to be
1092         // ready for non-caching in order to keep the Value() method simple and
1093         // efficient.
1094         const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1095         rfst->InitArcIterator(tuple_.fst_state, &local_data_);
1096         // arcs_ is a pointer to the arcs in the underlying machine.
1097         arcs_ = local_data_.arcs;
1098         // Computes the final arc (but not its destination state) if a final arc
1099         // is required.
1100         bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
1101             tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
1102         // Sets the arc value flags that hold for final_arc_.
1103         final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1104         // Computes the number of arcs.
1105         num_arcs_ = local_data_.narcs;
1106         if (has_final_arc) ++num_arcs_;
1107         // Sets the offset between the underlying arc positions and the
1108         // positions
1109         // in the arc iterator.
1110         offset_ = num_arcs_ - local_data_.narcs;
1111         // Defers the decision to cache or not until Value() or SetFlags() is
1112         // called.
1113         data_flags_ = 0;
1114       }
1115     }
1116   }
1117
1118   ~ArcIterator() {
1119     if (cache_data_.ref_count) --(*cache_data_.ref_count);
1120     if (local_data_.ref_count) --(*local_data_.ref_count);
1121   }
1122
1123   void ExpandAndCache() const  {
1124     // TODO(allauzen): revisit this.
1125     // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
1126     // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
1127     //                                               &cache_data_);
1128     //
1129     fst_.InitArcIterator(s_, &cache_data_);  // Expand and cache state.
1130     arcs_ = cache_data_.arcs;      // arcs_ is a pointer to the cached arcs.
1131     data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1132     offset_ = 0;                   // No offset.
1133   }
1134
1135   void Init() {
1136     if (flags_ & kArcNoCache) {  // If caching is disabled
1137       // arcs_ is a pointer to the arcs in the underlying machine.
1138       arcs_ = local_data_.arcs;
1139       // Sets the arcs value flags that hold for arcs_.
1140       data_flags_ = kArcWeightValue;
1141       if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
1142         data_flags_ |= kArcILabelValue;
1143       }
1144       // Sets the offset between the underlying arc positions and the positions
1145       // in the arc iterator.
1146       offset_ = num_arcs_ - local_data_.narcs;
1147     } else {
1148       ExpandAndCache();
1149     }
1150   }
1151
1152   bool Done() const { return pos_ >= num_arcs_; }
1153
1154   const Arc &Value() const {
1155     // If data_flags_ is 0, non-caching was not requested.
1156     if (!data_flags_) {
1157       // TODO(allauzen): Revisit this.
1158       if (flags_ & kArcNoCache) {
1159         // Should never happen.
1160         FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
1161       }
1162       ExpandAndCache();
1163     }
1164     if (pos_ - offset_ >= 0) {  // The requested arc is not the final arc.
1165       const auto &arc = arcs_[pos_ - offset_];
1166       if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1167         // If the value flags match the recquired value flags then returns the
1168         // arc.
1169         return arc;
1170       } else {
1171         // Otherwise, compute the corresponding arc on-the-fly.
1172         fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
1173                                           flags_ & kArcValueFlags);
1174         return arc_;
1175       }
1176     } else {  // The requested arc is the final arc.
1177       if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1178         // If the arc value flags that hold for the final arc do not match the
1179         // requested value flags, then
1180         // final_arc_ needs to be updated.
1181         fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
1182                                                flags_ & kArcValueFlags);
1183         final_flags_ = flags_ & kArcValueFlags;
1184       }
1185       return final_arc_;
1186     }
1187   }
1188
1189   void Next() { ++pos_; }
1190
1191   size_t Position() const { return pos_; }
1192
1193   void Reset() { pos_ = 0; }
1194
1195   void Seek(size_t pos) { pos_ = pos; }
1196
1197   uint32 Flags() const { return flags_; }
1198
1199   void SetFlags(uint32 flags, uint32 mask) {
1200     // Updates the flags taking into account what flags are supported
1201     // by the FST.
1202     flags_ &= ~mask;
1203     flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
1204     // If non-caching is not requested (and caching has not already been
1205     // performed), then flush data_flags_ to request caching during the next
1206     // call to Value().
1207     if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1208       if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
1209     }
1210     // If data_flags_ has been flushed but non-caching is requested before
1211     // calling Value(), then set up the iterator for non-caching.
1212     if ((flags & kArcNoCache) && (!data_flags_)) Init();
1213   }
1214
1215  private:
1216   const ReplaceFst<Arc, StateTable, CacheStore> &fst_;  // Reference to the FST.
1217   StateId s_;                                           // State in the FST.
1218   mutable StateTuple tuple_;  // Tuple corresponding to state_.
1219
1220   ssize_t pos_;             // Current position.
1221   mutable ssize_t offset_;  // Offset between position in iterator and in arcs_.
1222   ssize_t num_arcs_;        // Number of arcs at state_.
1223   uint32 flags_;            // Behavorial flags for the arc iterator
1224   mutable Arc arc_;         // Memory to temporarily store computed arcs.
1225
1226   mutable ArcIteratorData<Arc> cache_data_;  // Arc iterator data in cache.
1227   mutable ArcIteratorData<Arc> local_data_;  // Arc iterator data in local FST.
1228
1229   mutable const Arc *arcs_;     // Array of arcs.
1230   mutable uint32 data_flags_;   // Arc value flags valid for data in arcs_.
1231   mutable Arc final_arc_;       // Final arc (when required).
1232   mutable uint32 final_flags_;  // Arc value flags valid for final_arc_.
1233
1234   ArcIterator(const ArcIterator &) = delete;
1235   ArcIterator &operator=(const ArcIterator &) = delete;
1236 };
1237
1238 template <class Arc, class StateTable, class CacheStore>
1239 class ReplaceFstMatcher : public MatcherBase<Arc> {
1240  public:
1241   using Label = typename Arc::Label;
1242   using StateId = typename Arc::StateId;
1243   using Weight = typename Arc::Weight;
1244
1245   using FST = ReplaceFst<Arc, StateTable, CacheStore>;
1246   using LocalMatcher = MultiEpsMatcher<Matcher<Fst<Arc>>>;
1247
1248   using StateTuple = typename StateTable::StateTuple;
1249
1250   // This makes a copy of the FST.
1251   ReplaceFstMatcher(const ReplaceFst<Arc, StateTable, CacheStore> &fst,
1252                     MatchType match_type)
1253       : owned_fst_(fst.Copy()),
1254         fst_(*owned_fst_),
1255         impl_(fst_.GetMutableImpl()),
1256         s_(fst::kNoStateId),
1257         match_type_(match_type),
1258         current_loop_(false),
1259         final_arc_(false),
1260         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1261     if (match_type_ == fst::MATCH_OUTPUT) {
1262       std::swap(loop_.ilabel, loop_.olabel);
1263     }
1264     InitMatchers();
1265   }
1266
1267   // This doesn't copy the FST.
1268   ReplaceFstMatcher(const ReplaceFst<Arc, StateTable, CacheStore> *fst,
1269                     MatchType match_type)
1270       : fst_(*fst),
1271         impl_(fst_.GetMutableImpl()),
1272         s_(fst::kNoStateId),
1273         match_type_(match_type),
1274         current_loop_(false),
1275         final_arc_(false),
1276         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1277     if (match_type_ == fst::MATCH_OUTPUT) {
1278       std::swap(loop_.ilabel, loop_.olabel);
1279     }
1280     InitMatchers();
1281   }
1282
1283   // This makes a copy of the FST.
1284   ReplaceFstMatcher(
1285       const ReplaceFstMatcher<Arc, StateTable, CacheStore> &matcher,
1286       bool safe = false)
1287       : owned_fst_(matcher.fst_.Copy(safe)),
1288         fst_(*owned_fst_),
1289         impl_(fst_.GetMutableImpl()),
1290         s_(fst::kNoStateId),
1291         match_type_(matcher.match_type_),
1292         current_loop_(false),
1293         final_arc_(false),
1294         loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
1295     if (match_type_ == fst::MATCH_OUTPUT) {
1296       std::swap(loop_.ilabel, loop_.olabel);
1297     }
1298     InitMatchers();
1299   }
1300
1301   // Creates a local matcher for each component FST in the RTN. LocalMatcher is
1302   // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each
1303   // non-terminal arc, since these non-terminal
1304   // turn into epsilons on recursion.
1305   void InitMatchers() {
1306     const auto &fst_array = impl_->fst_array_;
1307     matcher_.resize(fst_array.size());
1308     for (Label i = 0; i < fst_array.size(); ++i) {
1309       if (fst_array[i]) {
1310         matcher_[i].reset(
1311             new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList));
1312         auto it = impl_->nonterminal_set_.begin();
1313         for (; it != impl_->nonterminal_set_.end(); ++it) {
1314           matcher_[i]->AddMultiEpsLabel(*it);
1315         }
1316       }
1317     }
1318   }
1319
1320   ReplaceFstMatcher<Arc, StateTable, CacheStore> *Copy(
1321       bool safe = false) const override {
1322     return new ReplaceFstMatcher<Arc, StateTable, CacheStore>(*this, safe);
1323   }
1324
1325   MatchType Type(bool test) const override {
1326     if (match_type_ == MATCH_NONE) return match_type_;
1327     const auto true_prop =
1328         match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
1329     const auto false_prop =
1330         match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
1331     const auto props = fst_.Properties(true_prop | false_prop, test);
1332     if (props & true_prop) {
1333       return match_type_;
1334     } else if (props & false_prop) {
1335       return MATCH_NONE;
1336     } else {
1337       return MATCH_UNKNOWN;
1338     }
1339   }
1340
1341   const Fst<Arc> &GetFst() const override { return fst_; }
1342
1343   uint64 Properties(uint64 props) const override { return props; }
1344
1345   // Sets the state from which our matching happens.
1346   void SetState(StateId s) final {
1347     if (s_ == s) return;
1348     s_ = s;
1349     tuple_ = impl_->GetStateTable()->Tuple(s_);
1350     if (tuple_.fst_state == kNoStateId) {
1351       done_ = true;
1352       return;
1353     }
1354     // Gets current matcher, used for non-epsilon matching.
1355     current_matcher_ = matcher_[tuple_.fst_id].get();
1356     current_matcher_->SetState(tuple_.fst_state);
1357     loop_.nextstate = s_;
1358     final_arc_ = false;
1359   }
1360
1361   // Searches for label from previous set state. If label == 0, first
1362   // hallucinate an epsilon loop; otherwise use the underlying matcher to
1363   // search for the label or epsilons. Note since the ReplaceFst recursion
1364   // on non-terminal arcs causes epsilon transitions to be created we use
1365   // MultiEpsilonMatcher to search for possible matches of non-terminals. If the
1366   // component FST
1367   // reaches a final state we also need to add the exiting final arc.
1368   bool Find(Label label) final {
1369     bool found = false;
1370     label_ = label;
1371     if (label_ == 0 || label_ == kNoLabel) {
1372       // Computes loop directly, avoiding Replace::ComputeArc.
1373       if (label_ == 0) {
1374         current_loop_ = true;
1375         found = true;
1376       }
1377       // Searches for matching multi-epsilons.
1378       final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
1379       found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1380     } else {
1381       // Searches on a sub machine directly using sub machine matcher.
1382       found = current_matcher_->Find(label_);
1383     }
1384     return found;
1385   }
1386
1387   bool Done() const final {
1388     return !current_loop_ && !final_arc_ && current_matcher_->Done();
1389   }
1390
1391   const Arc &Value() const final {
1392     if (current_loop_) return loop_;
1393     if (final_arc_) {
1394       impl_->ComputeFinalArc(tuple_, &arc_);
1395       return arc_;
1396     }
1397     const auto &component_arc = current_matcher_->Value();
1398     impl_->ComputeArc(tuple_, component_arc, &arc_);
1399     return arc_;
1400   }
1401
1402   void Next() final {
1403     if (current_loop_) {
1404       current_loop_ = false;
1405       return;
1406     }
1407     if (final_arc_) {
1408       final_arc_ = false;
1409       return;
1410     }
1411     current_matcher_->Next();
1412   }
1413
1414   ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
1415
1416  private:
1417   std::unique_ptr<const ReplaceFst<Arc, StateTable, CacheStore>> owned_fst_;
1418   const ReplaceFst<Arc, StateTable, CacheStore> &fst_;
1419   internal::ReplaceFstImpl<Arc, StateTable, CacheStore> *impl_;
1420   LocalMatcher *current_matcher_;
1421   std::vector<std::unique_ptr<LocalMatcher>> matcher_;
1422   StateId s_;             // Current state.
1423   Label label_;           // Current label.
1424   MatchType match_type_;  // Supplied by caller.
1425   mutable bool done_;
1426   mutable bool current_loop_;  // Current arc is the implicit loop.
1427   mutable bool final_arc_;     // Current arc for exiting recursion.
1428   mutable StateTuple tuple_;   // Tuple corresponding to state_.
1429   mutable Arc arc_;
1430   Arc loop_;
1431
1432   ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
1433 };
1434
1435 template <class Arc, class StateTable, class CacheStore>
1436 inline void ReplaceFst<Arc, StateTable, CacheStore>::InitStateIterator(
1437     StateIteratorData<Arc> *data) const {
1438   data->base =
1439       new StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(*this);
1440 }
1441
1442 using StdReplaceFst = ReplaceFst<StdArc>;
1443
1444 // Recursively replaces arcs in the root FSTs with other FSTs.
1445 // This version writes the result of replacement to an output MutableFst.
1446 //
1447 // Replace supports replacement of arcs in one Fst with another FST. This
1448 // replacement is recursive. Replace takes an array of FST(s). One FST
1449 // represents the root (or topology) machine. The root FST refers to other FSTs
1450 // by recursively replacing arcs labeled as non-terminals with the matching
1451 // non-terminal FST. Currently Replace uses the output symbols of the arcs to
1452 // determine whether the arc is a non-terminal arc or not. A non-terminal can be
1453 // any label that is not a non-zero terminal label in the output alphabet.
1454 //
1455 // Note that input argument is a vector of pairs. These correspond to the tuple
1456 // of non-terminal Label and corresponding FST.
1457 template <class Arc>
1458 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1459                  &ifst_array,
1460              MutableFst<Arc> *ofst,
1461              ReplaceFstOptions<Arc> opts = ReplaceFstOptions<Arc>()) {
1462   opts.gc = true;
1463   opts.gc_limit = 0;  // Caches only the last state for fastest copy.
1464   *ofst = ReplaceFst<Arc>(ifst_array, opts);
1465 }
1466
1467 template <class Arc>
1468 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1469                  &ifst_array,
1470              MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
1471   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1472 }
1473
1474 // For backwards compatibility.
1475 template <class Arc>
1476 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1477                  &ifst_array,
1478              MutableFst<Arc> *ofst, typename Arc::Label root,
1479              bool epsilon_on_replace) {
1480   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1481 }
1482
1483 template <class Arc>
1484 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1485                  &ifst_array,
1486              MutableFst<Arc> *ofst, typename Arc::Label root) {
1487   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1488 }
1489
1490 }  // namespace fst
1491
1492 #endif  // FST_REPLACE_H_