Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / compose.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to compute the composition of two FSTs.
5
6 #ifndef FST_COMPOSE_H_
7 #define FST_COMPOSE_H_
8
9 #include <algorithm>
10
11 #include <fst/log.h>
12
13 #include <fst/cache.h>
14 #include <fst/compose-filter.h>
15 #include <fst/fst-decl.h>  // For optional argument declarations
16 #include <fst/lookahead-filter.h>
17 #include <fst/matcher.h>
18 #include <fst/state-table.h>
19 #include <fst/test-properties.h>
20
21
22 namespace fst {
23
24 // Delayed composition options templated on the arc type, the matcher,
25 // the composition filter, and the composition state table.  By
26 // default, the matchers, filter, and state table are constructed by
27 // composition. If set below, the user can instead pass in these
28 // objects; in that case, ComposeFst takes their ownership. This
29 // version controls composition implemented between generic Fst<Arc>
30 // types and a shared matcher type M for Fst<Arc>. This should be
31 // adequate for most applications, giving a reasonable tradeoff
32 // between efficiency and code sharing (but see ComposeFstImplOptions).
33 template <class Arc, class M = Matcher<Fst<Arc>>,
34           class Filter = SequenceComposeFilter<M>,
35           class StateTable =
36               GenericComposeStateTable<Arc, typename Filter::FilterState>>
37 struct ComposeFstOptions : public CacheOptions {
38   M *matcher1;              // FST1 matcher.
39   M *matcher2;              // FST2 matcher.
40   Filter *filter;           // Composition filter.
41   StateTable *state_table;  // Composition state table.
42
43   explicit ComposeFstOptions(const CacheOptions &opts = CacheOptions(),
44                              M *matcher1 = nullptr, M *matcher2 = nullptr,
45                              Filter *filter = nullptr,
46                              StateTable *state_table = nullptr)
47       : CacheOptions(opts),
48         matcher1(matcher1),
49         matcher2(matcher2),
50         filter(filter),
51         state_table(state_table) {}
52 };
53
54 // Forward declaration of ComposeFstMatcher.
55 template <class C, class F, class T>
56 class ComposeFstMatcher;
57
58 // Delayed composition options templated on the two matcher types, the
59 // composition filter, the composition state table and the cache store. By
60 // default, the matchers, filter, state table and cache store are constructed
61 // by composition. If set below, the user can instead pass in these objects; in
62 // that case, ComposeFst takes their ownership. This version controls
63 // composition implemented using arbitrary matchers (of the same arc type but
64 // otherwise arbitrary FST type). The user must ensure the matchers are
65 // compatible. These options permit the most efficient use, but shares the
66 // least code. This is for advanced use only in the most demanding or
67 // specialized applications that can benefit from it; otherwise, prefer
68 // ComposeFstOptions).
69 template <class M1, class M2, class Filter = SequenceComposeFilter<M1, M2>,
70           class StateTable = GenericComposeStateTable<
71               typename M1::Arc, typename Filter::FilterState>,
72           class CacheStore = DefaultCacheStore<typename M1::Arc>>
73 struct ComposeFstImplOptions : public CacheImplOptions<CacheStore> {
74   M1 *matcher1;    // FST1 matcher (see matcher.h)....
75   M2 *matcher2;    // FST2 matcher.
76   Filter *filter;  // Composition filter (see compose-filter.h).
77   StateTable
78     *state_table;        // Composition state table (see compose-state-table.h).
79   bool own_state_table;   // ComposeFstImpl takes ownership of 'state_table'?
80   bool allow_noncommute;  // Allow non-commutative weights
81
82   explicit ComposeFstImplOptions(const CacheOptions &opts,
83                                  M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
84                                  Filter *filter = nullptr,
85                                  StateTable *state_table = nullptr)
86       : CacheImplOptions<CacheStore>(opts),
87         matcher1(matcher1),
88         matcher2(matcher2),
89         filter(filter),
90         state_table(state_table),
91         own_state_table(true),
92         allow_noncommute(false) {}
93
94   explicit ComposeFstImplOptions(const CacheImplOptions<CacheStore> &opts,
95                                  M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
96                                  Filter *filter = nullptr,
97                                  StateTable *state_table = nullptr)
98       : CacheImplOptions<CacheStore>(opts),
99         matcher1(matcher1),
100         matcher2(matcher2),
101         filter(filter),
102         state_table(state_table),
103         own_state_table(true),
104         allow_noncommute(false) {}
105
106   ComposeFstImplOptions()
107       : matcher1(nullptr),
108         matcher2(nullptr),
109         filter(nullptr),
110         state_table(nullptr),
111         own_state_table(true),
112         allow_noncommute(false) {}
113 };
114
115 namespace internal {
116
117 // Implementation of delayed composition. This base class is common to the
118 // variants with different matchers, composition filters and state tables.
119 template <class Arc, class CacheStore = DefaultCacheStore<Arc>,
120           class F = ComposeFst<Arc, CacheStore>>
121 class ComposeFstImplBase
122     : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
123  public:
124   using FST = F;
125   using Label = typename Arc::Label;
126   using StateId = typename Arc::StateId;
127   using Weight = typename Arc::Weight;
128
129   using State = typename CacheStore::State;
130   using CacheImpl = CacheBaseImpl<State, CacheStore>;
131
132   using FstImpl<Arc>::SetType;
133   using FstImpl<Arc>::SetProperties;
134   using FstImpl<Arc>::Properties;
135   using FstImpl<Arc>::SetInputSymbols;
136   using FstImpl<Arc>::SetOutputSymbols;
137
138   using CacheImpl::HasStart;
139   using CacheImpl::HasFinal;
140   using CacheImpl::HasArcs;
141   using CacheImpl::SetFinal;
142   using CacheImpl::SetStart;
143
144   ComposeFstImplBase(const CacheImplOptions<CacheStore> &opts)
145       : CacheImpl(opts) {}
146
147   ComposeFstImplBase(const CacheOptions &opts) : CacheImpl(opts) {}
148
149   ComposeFstImplBase(const ComposeFstImplBase &impl) : CacheImpl(impl, true) {
150     SetType(impl.Type());
151     SetProperties(impl.Properties(), kCopyProperties);
152     SetInputSymbols(impl.InputSymbols());
153     SetOutputSymbols(impl.OutputSymbols());
154   }
155
156   virtual ComposeFstImplBase *Copy() const = 0;
157
158   ~ComposeFstImplBase() override {}
159
160   StateId Start() {
161     if (!HasStart()) {
162       const auto start = ComputeStart();
163       if (start != kNoStateId) SetStart(start);
164     }
165     return CacheImpl::Start();
166   }
167
168   Weight Final(StateId s) {
169     if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
170     return CacheImpl::Final(s);
171   }
172
173   virtual void Expand(StateId s) = 0;
174
175   size_t NumArcs(StateId s) {
176     if (!HasArcs(s)) Expand(s);
177     return CacheImpl::NumArcs(s);
178   }
179
180   size_t NumInputEpsilons(StateId s) {
181     if (!HasArcs(s)) Expand(s);
182     return CacheImpl::NumInputEpsilons(s);
183   }
184
185   size_t NumOutputEpsilons(StateId s) {
186     if (!HasArcs(s)) Expand(s);
187     return CacheImpl::NumOutputEpsilons(s);
188   }
189
190   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
191     if (!HasArcs(s)) Expand(s);
192     CacheImpl::InitArcIterator(s, data);
193   }
194
195   virtual MatcherBase<Arc> *InitMatcher(const F &fst,
196                                         MatchType match_type) const {
197     // Use the default matcher if no override is provided.
198     return nullptr;
199   }
200
201  protected:
202   virtual StateId ComputeStart() = 0;
203   virtual Weight ComputeFinal(StateId s) = 0;
204 };
205
206 // Implementation of delayed composition templated on the matchers (see
207 // matcher.h), composition filter (see compose-filter.h) and the composition
208 // state table (see compose-state-table.h).
209 template <class CacheStore, class Filter, class StateTable>
210 class ComposeFstImpl
211     : public ComposeFstImplBase<typename CacheStore::Arc, CacheStore> {
212  public:
213   using Matcher1 = typename Filter::Matcher1;
214   using Matcher2 = typename Filter::Matcher2;
215
216   using FST1 = typename Matcher1::FST;
217   using FST2 = typename Matcher2::FST;
218
219   using Arc = typename CacheStore::Arc;
220   using Label = typename Arc::Label;
221   using StateId = typename Arc::StateId;
222   using Weight = typename Arc::Weight;
223
224   using FilterState = typename Filter::FilterState;
225   using State = typename CacheStore::State;
226
227   using CacheImpl = CacheBaseImpl<State, CacheStore>;
228
229   using StateTuple = typename StateTable::StateTuple;
230
231   friend class ComposeFstMatcher<CacheStore, Filter, StateTable>;
232
233   using FstImpl<Arc>::SetInputSymbols;
234   using FstImpl<Arc>::SetOutputSymbols;
235   using FstImpl<Arc>::SetType;
236   using FstImpl<Arc>::SetProperties;
237
238   template <class M1, class M2>
239   ComposeFstImpl(const FST1 &fst1, const FST2 &fst2,
240                  const ComposeFstImplOptions<M1, M2, Filter, StateTable,
241                                              CacheStore> &opts);
242
243   ComposeFstImpl(const ComposeFstImpl &impl)
244       : ComposeFstImplBase<Arc, CacheStore>(impl),
245         filter_(new Filter(*impl.filter_, true)),
246         matcher1_(filter_->GetMatcher1()),
247         matcher2_(filter_->GetMatcher2()),
248         fst1_(matcher1_->GetFst()),
249         fst2_(matcher2_->GetFst()),
250         state_table_(new StateTable(*impl.state_table_)),
251         own_state_table_(true),
252         match_type_(impl.match_type_) {}
253
254   ~ComposeFstImpl() override {
255     if (own_state_table_) delete state_table_;
256   }
257
258   ComposeFstImpl *Copy() const override { return new ComposeFstImpl(*this); }
259
260   uint64 Properties() const override { return Properties(kFstProperties); }
261
262   // Sets error if found, and returns other FST impl properties.
263   uint64 Properties(uint64 mask) const override {
264     if ((mask & kError) &&
265         (fst1_.Properties(kError, false) || fst2_.Properties(kError, false) ||
266          (matcher1_->Properties(0) & kError) ||
267          (matcher2_->Properties(0) & kError) |
268              (filter_->Properties(0) & kError) ||
269          state_table_->Error())) {
270       SetProperties(kError, kError);
271     }
272     return FstImpl<Arc>::Properties(mask);
273   }
274
275   // Arranges it so that the first arg to OrderedExpand is the Fst
276   // that will be matched on.
277   void Expand(StateId s) override {
278     const auto &tuple = state_table_->Tuple(s);
279     const auto s1 = tuple.StateId1();
280     const auto s2 = tuple.StateId2();
281     filter_->SetState(s1, s2, tuple.GetFilterState());
282     if (MatchInput(s1, s2)) {
283       OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true);
284     } else {
285       OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false);
286     }
287   }
288
289   const FST1 &GetFst1() const { return fst1_; }
290
291   const FST2 &GetFst2() const { return fst2_; }
292
293   const Matcher1 *GetMatcher1() const { return matcher1_; }
294
295   Matcher1 *GetMatcher1() { return matcher1_; }
296
297   const Matcher2 *GetMatcher2() const { return matcher2_; }
298
299   Matcher2 *GetMatcher2() { return matcher2_; }
300
301   const Filter *GetFilter() const { return filter_.get(); }
302
303   Filter *GetFilter() { return filter_.get(); }
304
305   const StateTable *GetStateTable() const { return state_table_; }
306
307   StateTable *GetStateTable() { return state_table_; }
308
309   MatcherBase<Arc> *InitMatcher(const ComposeFst<Arc, CacheStore> &fst,
310                                 MatchType match_type) const override {
311     const auto test_props = match_type == MATCH_INPUT
312                                 ? kFstProperties & ~kILabelInvariantProperties
313                                 : kFstProperties & ~kOLabelInvariantProperties;
314     // If both matchers support 'match_type' and we have a guarantee that a
315     // call to 'filter_->FilterArc(arc1, arc2)' will not modify the ilabel of
316     // arc1 when MATCH_INPUT or the olabel or arc2 when MATCH_OUTPUT, then
317     // ComposeFstMatcher can be used.
318     if ((matcher1_->Type(false) == match_type) &&
319         (matcher2_->Type(false) == match_type) &&
320         (filter_->Properties(test_props) == test_props)) {
321       return new ComposeFstMatcher<
322         CacheStore, Filter, StateTable>(&fst, match_type);
323     }
324     return nullptr;
325   }
326
327  private:
328   // This does that actual matching of labels in the composition. The
329   // arguments are ordered so matching is called on state 'sa' of
330   // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg
331   // determines whether the input or output label of arcs at 'sb' is
332   // the one to match on.
333   template <class FST, class Matcher>
334   void OrderedExpand(StateId s, const Fst<Arc> &, StateId sa, const FST &fstb,
335                      StateId sb, Matcher *matchera, bool match_input) {
336     matchera->SetState(sa);
337     // First processes non-consuming symbols (e.g., epsilons) on FSTA.
338     const Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0,
339                    Weight::One(), sb);
340     MatchArc(s, matchera, loop, match_input);
341     // Then processes matches on FSTB.
342     for (ArcIterator<FST> iterb(fstb, sb); !iterb.Done(); iterb.Next()) {
343       MatchArc(s, matchera, iterb.Value(), match_input);
344     }
345     CacheImpl::SetArcs(s);
346   }
347
348   // Matches a single transition from 'fstb' against 'fata' at 's'.
349   template <class Matcher>
350   void MatchArc(StateId s, Matcher *matchera, const Arc &arc,
351                 bool match_input) {
352     if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) {
353       for (; !matchera->Done(); matchera->Next()) {
354         auto arca = matchera->Value();
355         auto arcb = arc;
356         if (match_input) {
357           const auto &fs = filter_->FilterArc(&arcb, &arca);
358           if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
359         } else {
360           const auto &fs = filter_->FilterArc(&arca, &arcb);
361           if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
362         }
363       }
364     }
365   }
366
367   // Add a matching transition at 's'.
368   void AddArc(StateId s, const Arc &arc1, const Arc &arc2,
369               const FilterState &f) {
370     const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
371     const Arc oarc(arc1.ilabel, arc2.olabel, Times(arc1.weight, arc2.weight),
372                    state_table_->FindState(tuple));
373     CacheImpl::PushArc(s, oarc);
374   }
375
376   StateId ComputeStart() override {
377     const auto s1 = fst1_.Start();
378     if (s1 == kNoStateId) return kNoStateId;
379     const auto s2 = fst2_.Start();
380     if (s2 == kNoStateId) return kNoStateId;
381     const auto &fs = filter_->Start();
382     const StateTuple tuple(s1, s2, fs);
383     return state_table_->FindState(tuple);
384   }
385
386   Weight ComputeFinal(StateId s) override {
387     const auto &tuple = state_table_->Tuple(s);
388     const auto s1 = tuple.StateId1();
389     auto final1 = matcher1_->Final(s1);
390     if (final1 == Weight::Zero()) return final1;
391     const auto s2 = tuple.StateId2();
392     auto final2 = matcher2_->Final(s2);
393     if (final2 == Weight::Zero()) return final2;
394     filter_->SetState(s1, s2, tuple.GetFilterState());
395     filter_->FilterFinal(&final1, &final2);
396     return Times(final1, final2);
397   }
398
399   // Determines which side to match on per composition state.
400   bool MatchInput(StateId s1, StateId s2) {
401     switch (match_type_) {
402       case MATCH_INPUT:
403         return true;
404       case MATCH_OUTPUT:
405         return false;
406       default:  // MATCH_BOTH
407         const auto priority1 = matcher1_->Priority(s1);
408         const auto priority2 = matcher2_->Priority(s2);
409         if (priority1 == kRequirePriority && priority2 == kRequirePriority) {
410           FSTERROR() << "ComposeFst: Both sides can't require match";
411           SetProperties(kError, kError);
412           return true;
413         }
414         if (priority1 == kRequirePriority) return false;
415         if (priority2 == kRequirePriority) {
416           return true;
417         }
418         return priority1 <= priority2;
419     }
420   }
421
422   // Identifies and verifies the capabilities of the matcher to be used for
423   // composition.
424   void SetMatchType();
425
426   std::unique_ptr<Filter> filter_;
427   Matcher1 *matcher1_;  // Borrowed reference.
428   Matcher2 *matcher2_;  // Borrowed reference.
429   const FST1 &fst1_;
430   const FST2 &fst2_;
431   StateTable *state_table_;
432   bool own_state_table_;
433
434   MatchType match_type_;
435 };
436
437 template <class CacheStore, class Filter, class StateTable>
438 template <class M1, class M2>
439 ComposeFstImpl<CacheStore, Filter, StateTable>::ComposeFstImpl(
440     const FST1 &fst1, const FST2 &fst2,
441     const ComposeFstImplOptions<M1, M2, Filter, StateTable, CacheStore> &opts)
442     : ComposeFstImplBase<Arc, CacheStore>(opts),
443       filter_(opts.filter
444                   ? opts.filter
445                   : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
446       matcher1_(filter_->GetMatcher1()),
447       matcher2_(filter_->GetMatcher2()),
448       fst1_(matcher1_->GetFst()),
449       fst2_(matcher2_->GetFst()),
450       state_table_(opts.state_table ? opts.state_table
451                                     : new StateTable(fst1_, fst2_)),
452       own_state_table_(opts.state_table ? opts.own_state_table : true) {
453   SetType("compose");
454   if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
455     FSTERROR() << "ComposeFst: Output symbol table of 1st argument "
456                << "does not match input symbol table of 2nd argument";
457     SetProperties(kError, kError);
458   }
459   SetInputSymbols(fst1_.InputSymbols());
460   SetOutputSymbols(fst2_.OutputSymbols());
461   SetMatchType();
462   VLOG(2) << "ComposeFstImpl: Match type: " << match_type_;
463   if (match_type_ == MATCH_NONE) SetProperties(kError, kError);
464   const auto fprops1 = fst1.Properties(kFstProperties, false);
465   const auto fprops2 = fst2.Properties(kFstProperties, false);
466   const auto mprops1 = matcher1_->Properties(fprops1);
467   const auto mprops2 = matcher2_->Properties(fprops2);
468   const auto cprops = ComposeProperties(mprops1, mprops2);
469   SetProperties(filter_->Properties(cprops), kCopyProperties);
470   if (state_table_->Error()) SetProperties(kError, kError);
471 }
472
473 template <class CacheStore, class Filter, class StateTable>
474 void ComposeFstImpl<CacheStore, Filter, StateTable>::SetMatchType() {
475   // Ensures any required matching is possible and known.
476   if ((matcher1_->Flags() & kRequireMatch) &&
477       matcher1_->Type(true) != MATCH_OUTPUT) {
478     FSTERROR() << "ComposeFst: 1st argument cannot perform required matching "
479                << "(sort?).";
480     match_type_ = MATCH_NONE;
481     return;
482   }
483   if ((matcher2_->Flags() & kRequireMatch) &&
484       matcher2_->Type(true) != MATCH_INPUT) {
485     FSTERROR() << "ComposeFst: 2nd argument cannot perform required matching "
486                << "(sort?).";
487     match_type_ = MATCH_NONE;
488     return;
489   }
490   // Finds which sides to match on (favoring minimal testing of capabilities).
491   const auto type1 = matcher1_->Type(false);
492   const auto type2 = matcher2_->Type(false);
493   if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) {
494     match_type_ = MATCH_BOTH;
495   } else if (type1 == MATCH_OUTPUT) {
496     match_type_ = MATCH_OUTPUT;
497   } else if (type2 == MATCH_INPUT) {
498     match_type_ = MATCH_INPUT;
499   } else if (matcher1_->Type(true) == MATCH_OUTPUT) {
500     match_type_ = MATCH_OUTPUT;
501   } else if (matcher2_->Type(true) == MATCH_INPUT) {
502     match_type_ = MATCH_INPUT;
503   } else {
504     FSTERROR() << "ComposeFst: 1st argument cannot match on output labels "
505                << "and 2nd argument cannot match on input labels (sort?).";
506     match_type_ = MATCH_NONE;
507   }
508 }
509
510 }  // namespace internal
511
512 // Computes the composition of two transducers. This version is a delayed FST.
513 // If FST1 transduces string x to y with weight a and FST2 transduces y to z
514 // with weight b, then their composition transduces string x to z with weight
515 // Times(x, z).
516 //
517 // The output labels of the first transducer or the input labels of the second
518 // transducer must be sorted (with the default matcher). The weights need to
519 // form a commutative semiring (valid for TropicalWeight and LogWeight).
520 //
521 // Complexity:
522 //
523 // Assuming the first FST is unsorted and the second is sorted,
524 //
525 //   Time: O(v1 v2 d1 (log d2 + m2)),
526 //   Space: O(v1 v2)
527 //
528 // where vi = # of states visited, di = maximum out-degree, and mi the
529 // maximum multiplicity of the states visited, for the ith FST. Constant time
530 // and space to visit an input state or arc is assumed and exclusive of caching.
531 //
532 // Caveats:
533 // - ComposeFst does not trim its output (since it is a delayed operation).
534 // - The efficiency of composition can be strongly affected by several factors:
535 //   - the choice of which transducer is sorted - prefer sorting the FST
536 //     that has the greater average out-degree.
537 //   - the amount of non-determinism
538 //   - the presence and location of epsilon transitions - avoid epsilon
539 //     transitions on the output side of the first transducer or
540 //     the input side of the second transducer or prefer placing
541 //     them later in a path since they delay matching and can
542 //     introduce non-coaccessible states and transitions.
543 //
544 // This class attaches interface to implementation and handles reference
545 // counting, delegating most methods to ImplToFst. The CacheStore specifies the
546 // cache store (default declared in fst-decl.h).
547 template <class A, class CacheStore /* = DefaultCacheStore<A> */>
548 class ComposeFst
549     : public ImplToFst<internal::ComposeFstImplBase<A, CacheStore>> {
550  public:
551   using Arc = A;
552   using StateId = typename Arc::StateId;
553   using Weight = typename Arc::Weight;
554
555   using Store = CacheStore;
556   using State = typename CacheStore::State;
557
558   using Impl = internal::ComposeFstImplBase<A, CacheStore>;
559
560   friend class ArcIterator<ComposeFst<Arc, CacheStore>>;
561   friend class StateIterator<ComposeFst<Arc, CacheStore>>;
562   template <class, class, class> friend class ComposeFstMatcher;
563
564   // Compose specifying only caching options.
565   ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
566              const CacheOptions &opts = CacheOptions())
567       : ImplToFst<Impl>(CreateBase(fst1, fst2, opts)) {}
568
569   // Compose specifying one shared matcher type M. Requires that the input FSTs
570   // and matcher FST types be Fst<Arc>. Recommended for best code-sharing and
571   // matcher compatiblity.
572   template <class Matcher, class Filter, class StateTuple>
573   ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
574              const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts)
575       : ImplToFst<Impl>(CreateBase1(fst1, fst2, opts)) {}
576
577   // Compose specifying two matcher types Matcher1 and Matcher2. Requires input
578   // FST (of the same Arc type, but o.w. arbitrary) match the corresponding
579   // matcher FST types). Recommended only for advanced use in demanding or
580   // specialized applications due to potential code bloat and matcher
581   // incompatibilities.
582   template <class Matcher1, class Matcher2, class Filter, class StateTuple>
583   ComposeFst(const typename Matcher1::FST &fst1,
584              const typename Matcher2::FST &fst2,
585              const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
586                                          CacheStore> &opts)
587       : ImplToFst<Impl>(CreateBase2(fst1, fst2, opts)) {}
588
589   // See Fst<>::Copy() for doc.
590   ComposeFst(const ComposeFst<A, CacheStore> &fst, bool safe = false)
591       : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
592                              : fst.GetSharedImpl()) {}
593
594   // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc.
595   ComposeFst<A, CacheStore> *Copy(bool safe = false) const override {
596     return new ComposeFst<A, CacheStore>(*this, safe);
597   }
598
599   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
600
601   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
602     GetMutableImpl()->InitArcIterator(s, data);
603   }
604
605   MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
606     return GetImpl()->InitMatcher(*this, match_type);
607   }
608
609  protected:
610   using ImplToFst<Impl>::GetImpl;
611   using ImplToFst<Impl>::GetMutableImpl;
612
613   explicit ComposeFst(std::shared_ptr<Impl> impl) : ImplToFst<Impl>(impl) {}
614
615   // Create compose implementation specifying two matcher types.
616   template <class Matcher1, class Matcher2, class Filter, class StateTuple>
617   static std::shared_ptr<Impl> CreateBase2(
618       const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2,
619       const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
620                                   CacheStore> &opts) {
621     auto impl = std::make_shared<
622         internal::ComposeFstImpl<CacheStore, Filter, StateTuple>>(fst1, fst2,
623                                                                   opts);
624     if (!(Weight::Properties() & kCommutative) && !opts.allow_noncommute) {
625       const auto props1 = fst1.Properties(kUnweighted, true);
626       const auto props2 = fst2.Properties(kUnweighted, true);
627       if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) {
628         FSTERROR() << "ComposeFst: Weights must be a commutative semiring: "
629                    << Weight::Type();
630         impl->SetProperties(kError, kError);
631       }
632     }
633     return impl;
634   }
635
636   // Create compose implementation specifying one matcher type; requires that
637   // input and matcher FST types be Fst<Arc>.
638   template <class Matcher, class Filter, class StateTuple>
639   static std::shared_ptr<Impl> CreateBase1(
640       const Fst<Arc> &fst1, const Fst<Arc> &fst2,
641       const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts) {
642     ComposeFstImplOptions<Matcher, Matcher, Filter, StateTuple, CacheStore>
643         nopts(opts, opts.matcher1, opts.matcher2, opts.filter,
644               opts.state_table);
645     return CreateBase2(fst1, fst2, nopts);
646   }
647
648   // Create compose implementation specifying no matcher type.
649   static std::shared_ptr<Impl> CreateBase(const Fst<Arc> &fst1,
650                                           const Fst<Arc> &fst2,
651                                           const CacheOptions &opts) {
652     switch (LookAheadMatchType(fst1, fst2)) {  // Check for lookahead matchers
653       default:
654       case MATCH_NONE: {  // Default composition (no look-ahead).
655         ComposeFstOptions<Arc> nopts(opts);
656         return CreateBase1(fst1, fst2, nopts);
657       }
658       case MATCH_OUTPUT: {  // Lookahead on fst1.
659         using M = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::FstMatcher;
660         using F = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::ComposeFilter;
661         ComposeFstOptions<Arc, M, F> nopts(opts);
662         return CreateBase1(fst1, fst2, nopts);
663       }
664       case MATCH_INPUT: {  // Lookahead on fst2
665         using M = typename DefaultLookAhead<Arc, MATCH_INPUT>::FstMatcher;
666         using F = typename DefaultLookAhead<Arc, MATCH_INPUT>::ComposeFilter;
667         ComposeFstOptions<Arc, M, F> nopts(opts);
668         return CreateBase1(fst1, fst2, nopts);
669       }
670     }
671   }
672
673  private:
674   ComposeFst &operator=(const ComposeFst &fst) = delete;
675 };
676
677 // Specialization for ComposeFst.
678 template <class Arc, class CacheStore>
679 class StateIterator<ComposeFst<Arc, CacheStore>>
680     : public CacheStateIterator<ComposeFst<Arc, CacheStore>> {
681  public:
682   explicit StateIterator(const ComposeFst<Arc, CacheStore> &fst)
683       : CacheStateIterator<ComposeFst<Arc, CacheStore>>(fst,
684                                                         fst.GetMutableImpl()) {}
685 };
686
687 // Specialization for ComposeFst.
688 template <class Arc, class CacheStore>
689 class ArcIterator<ComposeFst<Arc, CacheStore>>
690     : public CacheArcIterator<ComposeFst<Arc, CacheStore>> {
691  public:
692   using StateId = typename Arc::StateId;
693
694   ArcIterator(const ComposeFst<Arc, CacheStore> &fst, StateId s)
695       : CacheArcIterator<ComposeFst<Arc, CacheStore>>(fst.GetMutableImpl(), s) {
696     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
697   }
698 };
699
700 template <class Arc, class CacheStore>
701 inline void ComposeFst<Arc, CacheStore>::InitStateIterator(
702     StateIteratorData<Arc> *data) const {
703   data->base = new StateIterator<ComposeFst<Arc, CacheStore>>(*this);
704 }
705
706 // Specialized matcher for ComposeFst. Supports MATCH_INPUT or MATCH_OUTPUT,
707 // iff the underlying matchers for the two FSTS being composed support
708 // MATCH_INPUT or MATCH_OUTPUT, respectively.
709 template <class CacheStore, class Filter, class StateTable>
710 class ComposeFstMatcher : public MatcherBase<typename CacheStore::Arc> {
711  public:
712   using Arc = typename CacheStore::Arc;
713   using Label = typename Arc::Label;
714   using StateId = typename Arc::StateId;
715   using Weight = typename Arc::Weight;
716
717   using Matcher1 = typename Filter::Matcher1;
718   using Matcher2 = typename Filter::Matcher2;
719   using FilterState = typename Filter::FilterState;
720
721   using StateTuple = typename StateTable::StateTuple;
722   using Impl = internal::ComposeFstImpl<CacheStore, Filter, StateTable>;
723
724   // The compose FST arg must match the filter and state table types.
725   // This makes a copy of the FST.
726   ComposeFstMatcher(const ComposeFst<Arc, CacheStore> &fst,
727                     MatchType match_type)
728       : owned_fst_(fst.Copy()),
729         fst_(*owned_fst_),
730         impl_(static_cast<const Impl *>(fst_.GetImpl())),
731         s_(kNoStateId),
732         match_type_(match_type),
733         matcher1_(impl_->matcher1_->Copy()),
734         matcher2_(impl_->matcher2_->Copy()),
735         current_loop_(false),
736         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
737     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
738   }
739
740   // The compose FST arg must match the filter and state table types.
741   // This doesn't copy the FST (although it may copy components).
742   ComposeFstMatcher(const ComposeFst<Arc, CacheStore> *fst,
743                     MatchType match_type)
744       : fst_(*fst),
745         impl_(static_cast<const Impl *>(fst_.GetImpl())),
746         s_(kNoStateId),
747         match_type_(match_type),
748         matcher1_(impl_->matcher1_->Copy()),
749         matcher2_(impl_->matcher2_->Copy()),
750         current_loop_(false),
751         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
752     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
753   }
754
755   // This makes a copy of the FST.
756   ComposeFstMatcher(
757       const ComposeFstMatcher<CacheStore, Filter, StateTable> &matcher,
758       bool safe = false)
759       : owned_fst_(matcher.fst_.Copy(safe)),
760         fst_(*owned_fst_),
761         impl_(static_cast<const Impl *>(fst_.GetImpl())),
762         s_(kNoStateId),
763         match_type_(matcher.match_type_),
764         matcher1_(matcher.matcher1_->Copy(safe)),
765         matcher2_(matcher.matcher2_->Copy(safe)),
766         current_loop_(false),
767         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
768     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
769   }
770
771   ComposeFstMatcher<CacheStore, Filter, StateTable> *Copy(
772       bool safe = false) const override {
773     return new ComposeFstMatcher<CacheStore, Filter, StateTable>(*this, safe);
774   }
775
776   MatchType Type(bool test) const override {
777     if ((matcher1_->Type(test) == MATCH_NONE) ||
778         (matcher2_->Type(test) == MATCH_NONE)) {
779       return MATCH_NONE;
780     }
781     if (((matcher1_->Type(test) == MATCH_UNKNOWN) &&
782          (matcher2_->Type(test) == MATCH_UNKNOWN)) ||
783         ((matcher1_->Type(test) == MATCH_UNKNOWN) &&
784          (matcher2_->Type(test) == match_type_)) ||
785         ((matcher1_->Type(test) == match_type_) &&
786          (matcher2_->Type(test) == MATCH_UNKNOWN))) {
787       return MATCH_UNKNOWN;
788     }
789     if ((matcher1_->Type(test) == match_type_) &&
790         (matcher2_->Type(test) == match_type_)) {
791       return match_type_;
792     }
793     return MATCH_NONE;
794   }
795
796   const Fst<Arc> &GetFst() const override { return fst_; }
797
798   uint64 Properties(uint64 inprops) const override {
799     return inprops;
800   }
801
802   void SetState(StateId s) final {
803     if (s_ == s) return;
804     s_ = s;
805     const auto &tuple = impl_->state_table_->Tuple(s);
806     matcher1_->SetState(tuple.StateId1());
807     matcher2_->SetState(tuple.StateId2());
808     loop_.nextstate = s_;
809   }
810
811   bool Find(Label label) final {
812     bool found = false;
813     current_loop_ = false;
814     if (label == 0) {
815       current_loop_ = true;
816       found = true;
817     }
818     if (match_type_ == MATCH_INPUT) {
819       found = found || FindLabel(label, matcher1_.get(), matcher2_.get());
820     } else {  // match_type_ == MATCH_OUTPUT
821       found = found || FindLabel(label, matcher2_.get(), matcher1_.get());
822     }
823     return found;
824   }
825
826   bool Done() const final {
827     return !current_loop_ && matcher1_->Done() && matcher2_->Done();
828   }
829
830   const Arc &Value() const final { return current_loop_ ? loop_ : arc_; }
831
832   void Next() final {
833     if (current_loop_) {
834       current_loop_ = false;
835     } else if (match_type_ == MATCH_INPUT) {
836       FindNext(matcher1_.get(), matcher2_.get());
837     } else {  // match_type_ == MATCH_OUTPUT
838       FindNext(matcher2_.get(), matcher1_.get());
839     }
840   }
841
842   ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
843
844  private:
845   // Processes a match with the filter and creates resulting arc.
846   bool MatchArc(StateId s, Arc arc1,
847                 Arc arc2) {  // FIXME(kbg): copy but not assignment.
848     const auto &fs = impl_->filter_->FilterArc(&arc1, &arc2);
849     if (fs == FilterState::NoState()) return false;
850     const StateTuple tuple(arc1.nextstate, arc2.nextstate, fs);
851     arc_.ilabel = arc1.ilabel;
852     arc_.olabel = arc2.olabel;
853     arc_.weight = Times(arc1.weight, arc2.weight);
854     arc_.nextstate = impl_->state_table_->FindState(tuple);
855     return true;
856   }
857
858   // Finds the first match allowed by the filter.
859   template <class MatcherA, class MatcherB>
860   bool FindLabel(Label label, MatcherA *matchera, MatcherB *matcherb) {
861     if (matchera->Find(label)) {
862       matcherb->Find(match_type_ == MATCH_INPUT ? matchera->Value().olabel
863                                                 : matchera->Value().ilabel);
864       return FindNext(matchera, matcherb);
865     }
866     return false;
867   }
868
869   // Finds the next match allowed by the filter, returning true iff such a
870   // match is found.
871   template <class MatcherA, class MatcherB>
872   bool FindNext(MatcherA *matchera, MatcherB *matcherb) {
873     // State when entering this function:
874     // 'matchera' is pointed to a match x, y for label x, and a match for y was
875     // requested on 'matcherb'.
876     while (!matchera->Done() || !matcherb->Done()) {
877       if (matcherb->Done()) {
878         // If no more matches for y on 'matcherb', moves forward on 'matchera'
879         // until a match x, y' is found such that there is a match for y' on
880         // 'matcherb'.
881         matchera->Next();
882         while (!matchera->Done() &&
883                !matcherb->Find(match_type_ == MATCH_INPUT
884                                    ? matchera->Value().olabel
885                                    : matchera->Value().ilabel)) {
886           matchera->Next();
887         }
888       }
889       while (!matcherb->Done()) {
890         // 'matchera' is pointing to a match x, y' ('arca') and 'matcherb' is
891         // pointing to a match y', z' ('arcb'). If combining these two arcs is
892         // allowed by the filter (hence resulting in an arc x, z') return true.
893         // Position 'matcherb' on the next potential match for y' before
894         // returning.
895         const auto &arca = matchera->Value();
896         const auto &arcb = matcherb->Value();
897         // Position 'matcherb' on the next potential match for y'.
898         matcherb->Next();
899         // Returns true If combining these two arcs is allowed by the filter
900         // (hence resulting in an arc x, z'); otherwise consider next match
901         // for y' on 'matcherb'.
902         if (MatchArc(s_, match_type_ == MATCH_INPUT ? arca : arcb,
903                      match_type_ == MATCH_INPUT ? arcb : arca)) {
904           return true;
905         }
906       }
907     }
908     // Both 'matchera' and 'matcherb' are done, no more match to analyse.
909     return false;
910   }
911
912   std::unique_ptr<const ComposeFst<Arc, CacheStore>> owned_fst_;
913   const ComposeFst<Arc, CacheStore> &fst_;
914   const Impl *impl_;
915   StateId s_;
916   MatchType match_type_;
917   std::unique_ptr<Matcher1> matcher1_;
918   std::unique_ptr<Matcher2> matcher2_;
919   bool current_loop_;
920   Arc loop_;
921   Arc arc_;
922 };
923
924 // Useful alias when using StdArc.
925 using StdComposeFst = ComposeFst<StdArc>;
926
927 enum ComposeFilter {
928   AUTO_FILTER,
929   NULL_FILTER,
930   TRIVIAL_FILTER,
931   SEQUENCE_FILTER,
932   ALT_SEQUENCE_FILTER,
933   MATCH_FILTER
934 };
935
936 struct ComposeOptions {
937   bool connect;               // Connect output?
938   ComposeFilter filter_type;  // Pre-defined filter to use.
939
940   explicit ComposeOptions(bool connect = true,
941                           ComposeFilter filter_type = AUTO_FILTER)
942       : connect(connect), filter_type(filter_type) {}
943 };
944
945 // Computes the composition of two transducers. This version writes
946 // the composed FST into a MutableFst. If FST1 transduces string x to
947 // y with weight a and FST2 transduces y to z with weight b, then
948 // their composition transduces string x to z with weight
949 // Times(x, z).
950 //
951 // The output labels of the first transducer or the input labels of
952 // the second transducer must be sorted.  The weights need to form a
953 // commutative semiring (valid for TropicalWeight and LogWeight).
954 //
955 // Complexity:
956 //
957 // Assuming the first FST is unsorted and the second is sorted:
958 //
959 //   Time: O(V1 V2 D1 (log D2 + M2)),
960 //   Space: O(V1 V2 D1 M2)
961 //
962 // where Vi = # of states, Di = maximum out-degree, and Mi is the maximum
963 // multiplicity, for the ith FST.
964 //
965 // Caveats:
966 //
967 // - Compose trims its output.
968 // - The efficiency of composition can be strongly affected by several factors:
969 //   - the choice of which transducer is sorted - prefer sorting the FST
970 //     that has the greater average out-degree.
971 //   - the amount of non-determinism
972 //   - the presence and location of epsilon transitions - avoid epsilon
973 //     transitions on the output side of the first transducer or
974 //     the input side of the second transducer or prefer placing
975 //     them later in a path since they delay matching and can
976 //     introduce non-coaccessible states and transitions.
977 template <class Arc>
978 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
979              MutableFst<Arc> *ofst,
980              const ComposeOptions &opts = ComposeOptions()) {
981   using M = Matcher<Fst<Arc>>;
982   // In each case, we cache only the last state for fastest copy.
983   switch (opts.filter_type) {
984     case AUTO_FILTER: {
985       CacheOptions nopts;
986       nopts.gc_limit = 0;
987       *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
988       break;
989     }
990     case NULL_FILTER: {
991       ComposeFstOptions<Arc, M, NullComposeFilter<M>> copts;
992       copts.gc_limit = 0;
993       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
994       break;
995     }
996     case SEQUENCE_FILTER: {
997       ComposeFstOptions<Arc, M, SequenceComposeFilter<M>> copts;
998       copts.gc_limit = 0;
999       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1000       break;
1001     }
1002     case ALT_SEQUENCE_FILTER: {
1003       ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>> copts;
1004       copts.gc_limit = 0;
1005       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1006       break;
1007     }
1008     case MATCH_FILTER: {
1009       ComposeFstOptions<Arc, M, MatchComposeFilter<M>> copts;
1010       copts.gc_limit = 0;
1011       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1012       break;
1013     }
1014     case TRIVIAL_FILTER: {
1015       ComposeFstOptions<Arc, M, TrivialComposeFilter<M>> copts;
1016       copts.gc_limit = 0;
1017       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1018       break;
1019     }
1020   }
1021   if (opts.connect) Connect(ofst);
1022 }
1023
1024 }  // namespace fst
1025
1026 #endif  // FST_COMPOSE_H_