Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / queue.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes for various FST state queues with a unified interface.
5
6 #ifndef FST_QUEUE_H_
7 #define FST_QUEUE_H_
8
9 #include <deque>
10 #include <memory>
11 #include <type_traits>
12 #include <utility>
13 #include <vector>
14
15 #include <fst/log.h>
16
17 #include <fst/arcfilter.h>
18 #include <fst/connect.h>
19 #include <fst/heap.h>
20 #include <fst/topsort.h>
21
22
23 namespace fst {
24
25 // The Queue interface is:
26 //
27 // template <class S>
28 // class Queue {
29 //  public:
30 //   using StateId = S;
31 //
32 //   // Constructor: may need args (e.g., FST, comparator) for some queues.
33 //   Queue(...) override;
34 //
35 //   // Returns the head of the queue.
36 //   StateId Head() const override;
37 //
38 //   // Inserts a state.
39 //   void Enqueue(StateId s) override;
40 //
41 //   // Removes the head of the queue.
42 //   void Dequeue() override;
43 //
44 //   // Updates ordering of state s when weight changes, if necessary.
45 //   void Update(StateId s) override;
46 //
47 //   // Is the queue empty?
48 //   bool Empty() const override;
49 //
50 //   // Removes all states from the queue.
51 //   void Clear() override;
52 // };
53
54 // State queue types.
55 enum QueueType {
56   TRIVIAL_QUEUE = 0,         // Single state queue.
57   FIFO_QUEUE = 1,            // First-in, first-out queue.
58   LIFO_QUEUE = 2,            // Last-in, first-out queue.
59   SHORTEST_FIRST_QUEUE = 3,  // Shortest-first queue.
60   TOP_ORDER_QUEUE = 4,       // Topologically-ordered queue.
61   STATE_ORDER_QUEUE = 5,     // State ID-ordered queue.
62   SCC_QUEUE = 6,             // Component graph top-ordered meta-queue.
63   AUTO_QUEUE = 7,            // Auto-selected queue.
64   OTHER_QUEUE = 8
65 };
66
67 // QueueBase, templated on the StateId, is a virtual base class shared by all
68 // queues considered by AutoQueue.
69 template <class S>
70 class QueueBase {
71  public:
72   using StateId = S;
73
74   virtual ~QueueBase() {}
75
76   // Concrete implementation.
77
78   explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
79
80   void SetError(bool error) { error_ = error; }
81
82   bool Error() const { return error_; }
83
84   QueueType Type() const { return queue_type_; }
85
86   // Virtual interface.
87
88   virtual StateId Head() const = 0;
89   virtual void Enqueue(StateId) = 0;
90   virtual void Dequeue() = 0;
91   virtual void Update(StateId) = 0;
92   virtual bool Empty() const = 0;
93   virtual void Clear() = 0;
94
95  private:
96   QueueType queue_type_;
97   bool error_;
98 };
99
100 // Trivial queue discipline; one may enqueue at most one state at a time. It
101 // can be used for strongly connected components with only one state and no
102 // self-loops.
103 template <class S>
104 class TrivialQueue : public QueueBase<S> {
105  public:
106   using StateId = S;
107
108   TrivialQueue() : QueueBase<StateId>(TRIVIAL_QUEUE), front_(kNoStateId) {}
109
110   virtual ~TrivialQueue() = default;
111
112   StateId Head() const final { return front_; }
113
114   void Enqueue(StateId s) final { front_ = s; }
115
116   void Dequeue() final { front_ = kNoStateId; }
117
118   void Update(StateId) final {}
119
120   bool Empty() const final { return front_ == kNoStateId; }
121
122   void Clear() final { front_ = kNoStateId; }
123
124  private:
125   StateId front_;
126 };
127
128 // First-in, first-out queue discipline.
129 //
130 // This is not a final class.
131 template <class S>
132 class FifoQueue : public QueueBase<S> {
133  public:
134   using StateId = S;
135
136   FifoQueue() : QueueBase<StateId>(FIFO_QUEUE) {}
137
138   virtual ~FifoQueue() = default;
139
140   StateId Head() const override { return queue_.back(); }
141
142   void Enqueue(StateId s) override { queue_.push_front(s); }
143
144   void Dequeue() override { queue_.pop_back(); }
145
146   void Update(StateId) override {}
147
148   bool Empty() const override { return queue_.empty(); }
149
150   void Clear() override { queue_.clear(); }
151
152  private:
153   std::deque<StateId> queue_;
154 };
155
156 // Last-in, first-out queue discipline.
157 template <class S>
158 class LifoQueue : public QueueBase<S> {
159  public:
160   using StateId = S;
161
162   LifoQueue() : QueueBase<StateId>(LIFO_QUEUE) {}
163
164   virtual ~LifoQueue() = default;
165
166   StateId Head() const final { return queue_.front(); }
167
168   void Enqueue(StateId s) final { queue_.push_front(s); }
169
170   void Dequeue() final { queue_.pop_front(); }
171
172   void Update(StateId) final {}
173
174   bool Empty() const final { return queue_.empty(); }
175
176   void Clear() final { queue_.clear(); }
177
178  private:
179   std::deque<StateId> queue_;
180 };
181
182 // Shortest-first queue discipline, templated on the StateId and as well as a
183 // comparison functor used to compare two StateIds. If a (single) state's order
184 // changes, it can be reordered in the queue with a call to Update(). If update
185 // is false, call to Update() does not reorder the queue.
186 //
187 // This is not a final class.
188 template <typename S, typename Compare, bool update = true>
189 class ShortestFirstQueue : public QueueBase<S> {
190  public:
191   using StateId = S;
192
193   explicit ShortestFirstQueue(Compare comp)
194       : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
195
196   virtual ~ShortestFirstQueue() = default;
197
198   StateId Head() const override { return heap_.Top(); }
199
200   void Enqueue(StateId s) override {
201     if (update) {
202       for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
203       key_[s] = heap_.Insert(s);
204     } else {
205       heap_.Insert(s);
206     }
207   }
208
209   void Dequeue() override {
210     if (update) {
211       key_[heap_.Pop()] = kNoStateId;
212     } else {
213       heap_.Pop();
214     }
215   }
216
217   void Update(StateId s) override {
218     if (!update) return;
219     if (s >= key_.size() || key_[s] == kNoStateId) {
220       Enqueue(s);
221     } else {
222       heap_.Update(key_[s], s);
223     }
224   }
225
226   bool Empty() const override { return heap_.Empty(); }
227
228   void Clear() override {
229     heap_.Clear();
230     if (update) key_.clear();
231   }
232
233  private:
234   Heap<StateId, Compare> heap_;
235   std::vector<ssize_t> key_;
236 };
237
238 namespace internal {
239
240 // Given a vector that maps from states to weights, and a comparison functor
241 // for weights, this class defines a comparison function object between states.
242 template <typename StateId, typename Less>
243 class StateWeightCompare {
244  public:
245   using Weight = typename Less::Weight;
246
247   StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
248       : weights_(weights), less_(less) {}
249
250   bool operator()(const StateId s1, const StateId s2) const {
251     return less_(weights_[s1], weights_[s2]);
252   }
253
254  private:
255   // Borrowed references.
256   const std::vector<Weight> &weights_;
257   const Less &less_;
258 };
259
260 }  // namespace internal
261
262 // Shortest-first queue discipline, templated on the StateId and Weight, is
263 // specialized to use the weight's natural order for the comparison function.
264 template <typename S, typename Weight>
265 class NaturalShortestFirstQueue final
266     : public ShortestFirstQueue<
267           S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
268  public:
269   using StateId = S;
270   using Compare = internal::StateWeightCompare<StateId, NaturalLess<Weight>>;
271
272   explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
273       : ShortestFirstQueue<StateId, Compare>(Compare(distance, less_)) {}
274
275   virtual ~NaturalShortestFirstQueue() = default;
276
277  private:
278   // This is non-static because the constructor for non-idempotent weights will
279   // result in a an error.
280   const NaturalLess<Weight> less_{};
281 };
282
283 // Topological-order queue discipline, templated on the StateId. States are
284 // ordered in the queue topologically. The FST must be acyclic.
285 template <class S>
286 class TopOrderQueue : public QueueBase<S> {
287  public:
288   using StateId = S;
289
290   // This constructor computes the topological order. It accepts an arc filter
291   // to limit the transitions considered in that computation (e.g., only the
292   // epsilon graph).
293   template <class Arc, class ArcFilter>
294   TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
295       : QueueBase<StateId>(TOP_ORDER_QUEUE),
296         front_(0),
297         back_(kNoStateId),
298         order_(0),
299         state_(0) {
300     bool acyclic;
301     TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
302     DfsVisit(fst, &top_order_visitor, filter);
303     if (!acyclic) {
304       FSTERROR() << "TopOrderQueue: FST is not acyclic";
305       QueueBase<S>::SetError(true);
306     }
307     state_.resize(order_.size(), kNoStateId);
308   }
309
310   // This constructor is passed the pre-computed topological order.
311   explicit TopOrderQueue(const std::vector<StateId> &order)
312       : QueueBase<StateId>(TOP_ORDER_QUEUE),
313         front_(0),
314         back_(kNoStateId),
315         order_(order),
316         state_(order.size(), kNoStateId) {}
317
318   virtual ~TopOrderQueue() = default;
319
320   StateId Head() const final { return state_[front_]; }
321
322   void Enqueue(StateId s) final {
323     if (front_ > back_) {
324       front_ = back_ = order_[s];
325     } else if (order_[s] > back_) {
326       back_ = order_[s];
327     } else if (order_[s] < front_) {
328       front_ = order_[s];
329     }
330     state_[order_[s]] = s;
331   }
332
333   void Dequeue() final {
334     state_[front_] = kNoStateId;
335     while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
336   }
337
338   void Update(StateId) final {}
339
340   bool Empty() const final { return front_ > back_; }
341
342   void Clear() final {
343     for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
344     back_ = kNoStateId;
345     front_ = 0;
346   }
347
348  private:
349   StateId front_;
350   StateId back_;
351   std::vector<StateId> order_;
352   std::vector<StateId> state_;
353 };
354
355 // State order queue discipline, templated on the StateId. States are ordered in
356 // the queue by state ID.
357 template <class S>
358 class StateOrderQueue : public QueueBase<S> {
359  public:
360   using StateId = S;
361
362   StateOrderQueue()
363       : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
364
365   virtual ~StateOrderQueue() = default;
366
367   StateId Head() const final { return front_; }
368
369   void Enqueue(StateId s) final {
370     if (front_ > back_) {
371       front_ = back_ = s;
372     } else if (s > back_) {
373       back_ = s;
374     } else if (s < front_) {
375       front_ = s;
376     }
377     while (enqueued_.size() <= s) enqueued_.push_back(false);
378     enqueued_[s] = true;
379   }
380
381   void Dequeue() final {
382     enqueued_[front_] = false;
383     while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
384   }
385
386   void Update(StateId) final {}
387
388   bool Empty() const final { return front_ > back_; }
389
390   void Clear() final {
391     for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
392     front_ = 0;
393     back_ = kNoStateId;
394   }
395
396  private:
397   StateId front_;
398   StateId back_;
399   std::vector<bool> enqueued_;
400 };
401
402 // SCC topological-order meta-queue discipline, templated on the StateId and a
403 // queue used inside each SCC. It visits the SCCs of an FST in topological
404 // order. Its constructor is passed the queues to to use within an SCC.
405 template <class S, class Queue>
406 class SccQueue : public QueueBase<S> {
407  public:
408   using StateId = S;
409
410   // Constructor takes a vector specifying the SCC number per state and a
411   // vector giving the queue to use per SCC number.
412   SccQueue(const std::vector<StateId> &scc,
413            std::vector<std::unique_ptr<Queue>> *queue)
414       : QueueBase<StateId>(SCC_QUEUE),
415         queue_(queue),
416         scc_(scc),
417         front_(0),
418         back_(kNoStateId) {}
419
420   virtual ~SccQueue() = default;
421
422   StateId Head() const final {
423     while ((front_ <= back_) &&
424            (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
425             (((*queue_)[front_] == nullptr) &&
426              ((front_ >= trivial_queue_.size()) ||
427               (trivial_queue_[front_] == kNoStateId))))) {
428       ++front_;
429     }
430     if ((*queue_)[front_]) {
431       return (*queue_)[front_]->Head();
432     } else {
433       return trivial_queue_[front_];
434     }
435   }
436
437   void Enqueue(StateId s) final {
438     if (front_ > back_) {
439       front_ = back_ = scc_[s];
440     } else if (scc_[s] > back_) {
441       back_ = scc_[s];
442     } else if (scc_[s] < front_) {
443       front_ = scc_[s];
444     }
445     if ((*queue_)[scc_[s]]) {
446       (*queue_)[scc_[s]]->Enqueue(s);
447     } else {
448       while (trivial_queue_.size() <= scc_[s]) {
449         trivial_queue_.push_back(kNoStateId);
450       }
451       trivial_queue_[scc_[s]] = s;
452     }
453   }
454
455   void Dequeue() final {
456     if ((*queue_)[front_]) {
457       (*queue_)[front_]->Dequeue();
458     } else if (front_ < trivial_queue_.size()) {
459       trivial_queue_[front_] = kNoStateId;
460     }
461   }
462
463   void Update(StateId s) final {
464     if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
465   }
466
467   bool Empty() const final {
468     // Queues SCC number back_ is not empty unless back_ == front_.
469     if (front_ < back_) {
470       return false;
471     } else if (front_ > back_) {
472       return true;
473     } else if ((*queue_)[front_]) {
474       return (*queue_)[front_]->Empty();
475     } else {
476       return (front_ >= trivial_queue_.size()) ||
477              (trivial_queue_[front_] == kNoStateId);
478     }
479   }
480
481   void Clear() final {
482     for (StateId i = front_; i <= back_; ++i) {
483       if ((*queue_)[i]) {
484         (*queue_)[i]->Clear();
485       } else if (i < trivial_queue_.size()) {
486         trivial_queue_[i] = kNoStateId;
487       }
488     }
489     front_ = 0;
490     back_ = kNoStateId;
491   }
492
493  private:
494   std::vector<std::unique_ptr<Queue>> *queue_;
495   const std::vector<StateId> &scc_;
496   mutable StateId front_;
497   StateId back_;
498   std::vector<StateId> trivial_queue_;
499 };
500
501 // Automatic queue discipline. It selects a queue discipline for a given FST
502 // based on its properties.
503 template <class S>
504 class AutoQueue : public QueueBase<S> {
505  public:
506   using StateId = S;
507
508   // This constructor takes a state distance vector that, if non-null and if
509   // the Weight type has the path property, will entertain the shortest-first
510   // queue using the natural order w.r.t to the distance.
511   template <class Arc, class ArcFilter>
512   AutoQueue(const Fst<Arc> &fst,
513             const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
514       : QueueBase<StateId>(AUTO_QUEUE) {
515     using Weight = typename Arc::Weight;
516     // TrivialLess is never instantiated since the construction of Less is
517     // guarded by Properties() & kPath.  It is only here to avoid instantiating
518     // NaturalLess for non-path weights.
519     struct TrivialLess {
520       using Weight = typename Arc::Weight;
521       bool operator()(const Weight &, const Weight &) const { return false; }
522     };
523     using Less =
524         typename std::conditional<(Weight::Properties() & kPath) == kPath,
525                                   NaturalLess<Weight>, TrivialLess>::type;
526     using Compare = internal::StateWeightCompare<StateId, Less>;
527     // First checks if the FST is known to have these properties.
528     const auto props =
529         fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false);
530     if ((props & kTopSorted) || fst.Start() == kNoStateId) {
531       queue_.reset(new StateOrderQueue<StateId>());
532       VLOG(2) << "AutoQueue: using state-order discipline";
533     } else if (props & kAcyclic) {
534       queue_.reset(new TopOrderQueue<StateId>(fst, filter));
535       VLOG(2) << "AutoQueue: using top-order discipline";
536     } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) {
537       queue_.reset(new LifoQueue<StateId>());
538       VLOG(2) << "AutoQueue: using LIFO discipline";
539     } else {
540       uint64 properties;
541       // Decomposes into strongly-connected components.
542       SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
543       DfsVisit(fst, &scc_visitor, filter);
544       auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
545       std::vector<QueueType> queue_types(nscc);
546       std::unique_ptr<Less> less;
547       std::unique_ptr<Compare> comp;
548       if (distance && (Weight::Properties() & kPath)) {
549         less.reset(new Less);
550         comp.reset(new Compare(*distance, *less));
551       }
552       // Finds the queue type to use per SCC.
553       bool unweighted;
554       bool all_trivial;
555       SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
556                    &unweighted);
557       // If unweighted and semiring is idempotent, uses LIFO queue.
558       if (unweighted) {
559         queue_.reset(new LifoQueue<StateId>());
560         VLOG(2) << "AutoQueue: using LIFO discipline";
561         return;
562       }
563       // If all the SCC are trivial, the FST is acyclic and the scc number gives
564       // the topological order.
565       if (all_trivial) {
566         queue_.reset(new TopOrderQueue<StateId>(scc_));
567         VLOG(2) << "AutoQueue: using top-order discipline";
568         return;
569       }
570       VLOG(2) << "AutoQueue: using SCC meta-discipline";
571       queues_.resize(nscc);
572       for (StateId i = 0; i < nscc; ++i) {
573         switch (queue_types[i]) {
574           case TRIVIAL_QUEUE:
575             queues_[i].reset();
576             VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
577             break;
578           case SHORTEST_FIRST_QUEUE:
579             queues_[i].reset(
580                 new ShortestFirstQueue<StateId, Compare, false>(*comp));
581             VLOG(3) << "AutoQueue: SCC #" << i
582                     << ": using shortest-first discipline";
583             break;
584           case LIFO_QUEUE:
585             queues_[i].reset(new LifoQueue<StateId>());
586             VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
587             break;
588           case FIFO_QUEUE:
589           default:
590             queues_[i].reset(new FifoQueue<StateId>());
591             VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
592             break;
593         }
594       }
595       queue_.reset(new SccQueue<StateId, QueueBase<StateId>>(scc_, &queues_));
596     }
597   }
598
599   virtual ~AutoQueue() = default;
600
601   StateId Head() const final { return queue_->Head(); }
602
603   void Enqueue(StateId s) final { queue_->Enqueue(s); }
604
605   void Dequeue() final { queue_->Dequeue(); }
606
607   void Update(StateId s) final { queue_->Update(s); }
608
609   bool Empty() const final { return queue_->Empty(); }
610
611   void Clear() final { queue_->Clear(); }
612
613  private:
614   template <class Arc, class ArcFilter, class Less>
615   static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
616                            std::vector<QueueType> *queue_types,
617                            ArcFilter filter, Less *less, bool *all_trivial,
618                            bool *unweighted);
619
620   std::unique_ptr<QueueBase<StateId>> queue_;
621   std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
622   std::vector<StateId> scc_;
623 };
624
625 // Examines the states in an FST's strongly connected components and determines
626 // which type of queue to use per SCC. Stores result as a vector of QueueTypes
627 // which is assumed to have length equal to the number of SCCs. An arc filter
628 // is used to limit the transitions considered (e.g., only the epsilon graph).
629 // The argument all_trivial is set to true if every queue is the trivial queue.
630 // The argument unweighted is set to true if the semiring is idempotent and all
631 // the arc weights are equal to Zero() or One().
632 template <class StateId>
633 template <class Arc, class ArcFilter, class Less>
634 void AutoQueue<StateId>::SccQueueType(const Fst<Arc> &fst,
635                                       const std::vector<StateId> &scc,
636                                       std::vector<QueueType> *queue_type,
637                                       ArcFilter filter, Less *less,
638                                       bool *all_trivial, bool *unweighted) {
639   using StateId = typename Arc::StateId;
640   using Weight = typename Arc::Weight;
641   *all_trivial = true;
642   *unweighted = true;
643   for (StateId i = 0; i < queue_type->size(); ++i) {
644     (*queue_type)[i] = TRIVIAL_QUEUE;
645   }
646   for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
647     const auto state = sit.Value();
648     for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
649       const auto &arc = ait.Value();
650       if (!filter(arc)) continue;
651       if (scc[state] == scc[arc.nextstate]) {
652         auto &type = (*queue_type)[scc[state]];
653         if (!less || ((*less)(arc.weight, Weight::One()))) {
654           type = FIFO_QUEUE;
655         } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
656           if (!(Weight::Properties() & kIdempotent) ||
657               (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
658             type = SHORTEST_FIRST_QUEUE;
659           } else {
660             type = LIFO_QUEUE;
661           }
662         }
663         if (type != TRIVIAL_QUEUE) *all_trivial = false;
664       }
665       if (!(Weight::Properties() & kIdempotent) ||
666           (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
667         *unweighted = false;
668       }
669     }
670   }
671 }
672
673 // An A* estimate is a function object that maps from a state ID to a an
674 // estimate of the shortest distance to the final states.
675
676 // A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's
677 // algorithm.
678 template <typename StateId, typename Weight>
679 struct TrivialAStarEstimate {
680   const Weight &operator()(StateId) const { return Weight::One(); }
681 };
682
683 // A non-trivial A* estimate using a vector of the estimated future costs.
684 template <typename StateId, typename Weight>
685 class NaturalAStarEstimate {
686  public:
687   NaturalAStarEstimate(const std::vector<Weight> &beta) :
688           beta_(beta) {}
689
690   const Weight &operator()(StateId s) const { return beta_[s]; }
691
692  private:
693   const std::vector<Weight> &beta_;
694 };
695
696 // Given a vector that maps from states to weights representing the shortest
697 // distance from the initial state, a comparison function object between
698 // weights, and an estimate of the shortest distance to the final states, this
699 // class defines a comparison function object between states.
700 template <typename S, typename Less, typename Estimate>
701 class AStarWeightCompare {
702  public:
703   using StateId = S;
704   using Weight = typename Less::Weight;
705
706   AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
707                      const Estimate &estimate)
708       : weights_(weights), less_(less), estimate_(estimate) {}
709
710   bool operator()(StateId s1, StateId s2) const {
711     const auto w1 = Times(weights_[s1], estimate_(s1));
712     const auto w2 = Times(weights_[s2], estimate_(s2));
713     return less_(w1, w2);
714   }
715
716  private:
717   const std::vector<Weight> &weights_;
718   const Less &less_;
719   const Estimate &estimate_;
720 };
721
722 // A* queue discipline templated on StateId, Weight, and Estimate.
723 template <typename S, typename Weight, typename Estimate>
724 class NaturalAStarQueue : public ShortestFirstQueue<
725           S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
726  public:
727   using StateId = S;
728   using Compare = AStarWeightCompare<StateId, NaturalLess<Weight>, Estimate>;
729
730   NaturalAStarQueue(const std::vector<Weight> &distance,
731                     const Estimate &estimate)
732       : ShortestFirstQueue<StateId, Compare>(
733             Compare(distance, less_, estimate)) {}
734
735   ~NaturalAStarQueue() = default;
736
737  private:
738   // This is non-static because the constructor for non-idempotent weights will
739   // result in a an error.
740   const NaturalLess<Weight> less_{};
741 };
742
743 // A state equivalence class is a function object that maps from a state ID to
744 // an equivalence class (state) ID. The trivial equivalence class maps a state
745 // ID to itself.
746 template <typename StateId>
747 struct TrivialStateEquivClass {
748   StateId operator()(StateId s) const { return s; }
749 };
750
751 // Distance-based pruning queue discipline: Enqueues a state only when its
752 // shortest distance (so far), as specified by distance, is less than (as
753 // specified by comp) the shortest distance Times() the threshold to any state
754 // in the same equivalence class, as specified by the functor class_func. The
755 // underlying queue discipline is specified by queue. The ownership of queue is
756 // given to this class.
757 //
758 // This is not a final class.
759 template <typename Queue, typename Less, typename ClassFnc>
760 class PruneQueue : public QueueBase<typename Queue::StateId> {
761  public:
762   using StateId = typename Queue::StateId;
763   using Weight = typename Less::Weight;
764
765   PruneQueue(const std::vector<Weight> &distance, Queue *queue,
766              const Less &less, const ClassFnc &class_fnc, Weight threshold)
767       : QueueBase<StateId>(OTHER_QUEUE),
768         distance_(distance),
769         queue_(queue),
770         less_(less),
771         class_fnc_(class_fnc),
772         threshold_(std::move(threshold)) {}
773
774   virtual ~PruneQueue() = default;
775
776   StateId Head() const override { return queue_->Head(); }
777
778   void Enqueue(StateId s) override {
779     const auto c = class_fnc_(s);
780     if (c >= class_distance_.size()) {
781       class_distance_.resize(c + 1, Weight::Zero());
782     }
783     if (less_(distance_[s], class_distance_[c])) {
784       class_distance_[c] = distance_[s];
785     }
786     // Enqueues only if below threshold limit.
787     const auto limit = Times(class_distance_[c], threshold_);
788     if (less_(distance_[s], limit)) queue_->Enqueue(s);
789   }
790
791   void Dequeue() override { queue_->Dequeue(); }
792
793   void Update(StateId s) override {
794     const auto c = class_fnc_(s);
795     if (less_(distance_[s], class_distance_[c])) {
796       class_distance_[c] = distance_[s];
797     }
798     queue_->Update(s);
799   }
800
801   bool Empty() const override { return queue_->Empty(); }
802
803   void Clear() override { queue_->Clear(); }
804
805  private:
806   const std::vector<Weight> &distance_;  // Shortest distance to state.
807   std::unique_ptr<Queue> queue_;
808   const Less &less_;                    // Borrowed reference.
809   const ClassFnc &class_fnc_;           // Equivalence class functor.
810   Weight threshold_;                    // Pruning weight threshold.
811   std::vector<Weight> class_distance_;  // Shortest distance to class.
812 };
813
814 // Pruning queue discipline (see above) using the weight's natural order for the
815 // comparison function. The ownership of the queue argument is given to this
816 // class.
817 template <typename Queue, typename Weight, typename ClassFnc>
818 class NaturalPruneQueue final
819     : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
820  public:
821   using StateId = typename Queue::StateId;
822
823   NaturalPruneQueue(const std::vector<Weight> &distance, Queue *queue,
824                     const ClassFnc &class_fnc, Weight threshold)
825       : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
826             distance, queue, NaturalLess<Weight>(), class_fnc, threshold) {}
827
828   virtual ~NaturalPruneQueue() = default;
829 };
830
831 // Filter-based pruning queue discipline: enqueues a state only if allowed by
832 // the filter, specified by the state filter functor argument. The underlying
833 // queue discipline is specified by the queue argument. The ownership of the
834 // queue is given to this class.
835 template <typename Queue, typename Filter>
836 class FilterQueue : public QueueBase<typename Queue::StateId> {
837  public:
838   using StateId = typename Queue::StateId;
839
840   FilterQueue(Queue *queue, const Filter &filter)
841       : QueueBase<StateId>(OTHER_QUEUE), queue_(queue), filter_(filter) {}
842
843   virtual ~FilterQueue() = default;
844
845   StateId Head() const final { return queue_->Head(); }
846
847   // Enqueues only if allowed by state filter.
848   void Enqueue(StateId s) final {
849     if (filter_(s)) queue_->Enqueue(s);
850   }
851
852   void Dequeue() final { queue_->Dequeue(); }
853
854   void Update(StateId s) final {}
855
856   bool Empty() const final { return queue_->Empty(); }
857
858   void Clear() final { queue_->Clear(); }
859
860  private:
861   std::unique_ptr<Queue> queue_;
862   const Filter &filter_;
863 };
864
865 }  // namespace fst
866
867 #endif  // FST_QUEUE_H_