bc74a64e33abeed70d7e156dceee08d641ea2266
[platform/upstream/openfst.git] / src / include / fst / queue.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes for various FST state queues with a unified interface.
5
6 #ifndef FST_QUEUE_H_
7 #define FST_QUEUE_H_
8
9 #include <deque>
10 #include <memory>
11 #include <type_traits>
12 #include <utility>
13 #include <vector>
14
15 #include <fst/log.h>
16
17 #include <fst/arcfilter.h>
18 #include <fst/connect.h>
19 #include <fst/heap.h>
20 #include <fst/topsort.h>
21
22
23 namespace fst {
24
25 // The Queue interface is:
26 //
27 // template <class S>
28 // class Queue {
29 //  public:
30 //   using StateId = S;
31 //
32 //   // Constructor: may need args (e.g., FST, comparator) for some queues.
33 //   Queue(...) override;
34 //
35 //   // Returns the head of the queue.
36 //   StateId Head() const override;
37 //
38 //   // Inserts a state.
39 //   void Enqueue(StateId s) override;
40 //
41 //   // Removes the head of the queue.
42 //   void Dequeue() override;
43 //
44 //   // Updates ordering of state s when weight changes, if necessary.
45 //   void Update(StateId s) override;
46 //
47 //   // Is the queue empty?
48 //   bool Empty() const override;
49 //
50 //   // Removes all states from the queue.
51 //   void Clear() override;
52 // };
53
54 // State queue types.
55 enum QueueType {
56   TRIVIAL_QUEUE = 0,         // Single state queue.
57   FIFO_QUEUE = 1,            // First-in, first-out queue.
58   LIFO_QUEUE = 2,            // Last-in, first-out queue.
59   SHORTEST_FIRST_QUEUE = 3,  // Shortest-first queue.
60   TOP_ORDER_QUEUE = 4,       // Topologically-ordered queue.
61   STATE_ORDER_QUEUE = 5,     // State ID-ordered queue.
62   SCC_QUEUE = 6,             // Component graph top-ordered meta-queue.
63   AUTO_QUEUE = 7,            // Auto-selected queue.
64   OTHER_QUEUE = 8
65 };
66
67 // QueueBase, templated on the StateId, is a virtual base class shared by all
68 // queues considered by AutoQueue.
69 template <class S>
70 class QueueBase {
71  public:
72   using StateId = S;
73
74   virtual ~QueueBase() {}
75
76   // Concrete implementation.
77
78   explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
79
80   void SetError(bool error) { error_ = error; }
81
82   bool Error() const { return error_; }
83
84   QueueType Type() const { return queue_type_; }
85
86   // Virtual interface.
87
88   virtual StateId Head() const = 0;
89   virtual void Enqueue(StateId) = 0;
90   virtual void Dequeue() = 0;
91   virtual void Update(StateId) = 0;
92   virtual bool Empty() const = 0;
93   virtual void Clear() = 0;
94
95  private:
96   QueueType queue_type_;
97   bool error_;
98 };
99
100 // Trivial queue discipline; one may enqueue at most one state at a time. It
101 // can be used for strongly connected components with only one state and no
102 // self-loops.
103 template <class S>
104 class TrivialQueue final : public QueueBase<S> {
105  public:
106   using StateId = S;
107
108   TrivialQueue() : QueueBase<StateId>(TRIVIAL_QUEUE), front_(kNoStateId) {}
109
110   virtual ~TrivialQueue() = default;
111
112   StateId Head() const override { return front_; }
113
114   void Enqueue(StateId s) override { front_ = s; }
115
116   void Dequeue() override { front_ = kNoStateId; }
117
118   void Update(StateId) override {}
119
120   bool Empty() const override { return front_ == kNoStateId; }
121
122   void Clear() override { front_ = kNoStateId; }
123
124  private:
125   StateId front_;
126 };
127
128 // First-in, first-out queue discipline.
129 //
130 // This is not a final class.
131 template <class S>
132 class FifoQueue : public QueueBase<S> {
133  public:
134   using StateId = S;
135
136   FifoQueue() : QueueBase<StateId>(FIFO_QUEUE) {}
137
138   virtual ~FifoQueue() = default;
139
140   StateId Head() const override { return queue_.back(); }
141
142   void Enqueue(StateId s) override { queue_.push_front(s); }
143
144   void Dequeue() override { queue_.pop_back(); }
145
146   void Update(StateId) override {}
147
148   bool Empty() const override { return queue_.empty(); }
149
150   void Clear() override { queue_.clear(); }
151
152  private:
153   std::deque<StateId> queue_;
154 };
155
156 // Last-in, first-out queue discipline.
157 template <class S>
158 class LifoQueue final : public QueueBase<S> {
159  public:
160   using StateId = S;
161
162   LifoQueue() : QueueBase<StateId>(LIFO_QUEUE) {}
163
164   virtual ~LifoQueue() = default;
165
166   StateId Head() const override { return queue_.front(); }
167
168   void Enqueue(StateId s) override { queue_.push_front(s); }
169
170   void Dequeue() override { queue_.pop_front(); }
171
172   void Update(StateId) override {}
173
174   bool Empty() const override { return queue_.empty(); }
175
176   void Clear() override { queue_.clear(); }
177
178  private:
179   std::deque<StateId> queue_;
180 };
181
182 // Shortest-first queue discipline, templated on the StateId and as well as a
183 // comparison functor used to compare two StateIds. If a (single) state's order
184 // changes, it can be reordered in the queue with a call to Update(). If update
185 // is false, call to Update() does not reorder the queue.
186 //
187 // This is not a final class.
188 template <typename S, typename Compare, bool update = true>
189 class ShortestFirstQueue : public QueueBase<S> {
190  public:
191   using StateId = S;
192
193   static constexpr StateId kNoKey = -1;
194
195   explicit ShortestFirstQueue(Compare comp)
196       : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
197
198   StateId Head() const override { return heap_.Top(); }
199
200   void Enqueue(StateId s) override {
201     if (update) {
202       for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoKey);
203       key_[s] = heap_.Insert(s);
204     } else {
205       heap_.Insert(s);
206     }
207   }
208
209   void Dequeue() override {
210     if (update) {
211       key_[heap_.Pop()] = kNoKey;
212     } else {
213       heap_.Pop();
214     }
215   }
216
217   void Update(StateId s) override {
218     if (!update) return;
219     if (s >= key_.size() || key_[s] == kNoKey) {
220       Enqueue(s);
221     } else {
222       heap_.Update(key_[s], s);
223     }
224   }
225
226   bool Empty() const override { return heap_.Empty(); }
227
228   void Clear() override {
229     heap_.Clear();
230     if (update) key_.clear();
231   }
232
233  private:
234   Heap<StateId, Compare> heap_;
235   std::vector<ssize_t> key_;
236 };
237
238 template <typename StateId, typename Compare, bool update>
239 constexpr StateId ShortestFirstQueue<StateId, Compare, update>::kNoKey;
240
241 namespace internal {
242
243 // Given a vector that maps from states to weights, and a comparison functor
244 // for weights, this class defines a comparison function object between states.
245 template <typename StateId, typename Less>
246 class StateWeightCompare {
247  public:
248   using Weight = typename Less::Weight;
249
250   StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
251       : weights_(weights), less_(less) {}
252
253   bool operator()(const StateId s1, const StateId s2) const {
254     return less_(weights_[s1], weights_[s2]);
255   }
256
257  private:
258   // Borrowed references.
259   const std::vector<Weight> &weights_;
260   const Less &less_;
261 };
262
263 }  // namespace internal
264
265 // Shortest-first queue discipline, templated on the StateId and Weight, is
266 // specialized to use the weight's natural order for the comparison function.
267 template <typename S, typename Weight>
268 class NaturalShortestFirstQueue final
269     : public ShortestFirstQueue<
270           S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
271  public:
272   using StateId = S;
273   using Compare = internal::StateWeightCompare<StateId, NaturalLess<Weight>>;
274
275   explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
276       : ShortestFirstQueue<StateId, Compare>(Compare(distance, less_)) {}
277
278   virtual ~NaturalShortestFirstQueue() = default;
279
280  private:
281   // This is non-static because the constructor for non-idempotent weights will
282   // result in a an error.
283   const NaturalLess<Weight> less_{};
284 };
285
286 // Topological-order queue discipline, templated on the StateId. States are
287 // ordered in the queue topologically. The FST must be acyclic.
288 template <class S>
289 class TopOrderQueue final : public QueueBase<S> {
290  public:
291   using StateId = S;
292
293   // This constructor computes the topological order. It accepts an arc filter
294   // to limit the transitions considered in that computation (e.g., only the
295   // epsilon graph).
296   template <class Arc, class ArcFilter>
297   TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
298       : QueueBase<StateId>(TOP_ORDER_QUEUE),
299         front_(0),
300         back_(kNoStateId),
301         order_(0),
302         state_(0) {
303     bool acyclic;
304     TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
305     DfsVisit(fst, &top_order_visitor, filter);
306     if (!acyclic) {
307       FSTERROR() << "TopOrderQueue: FST is not acyclic";
308       QueueBase<S>::SetError(true);
309     }
310     state_.resize(order_.size(), kNoStateId);
311   }
312
313   // This constructor is passed the pre-computed topological order.
314   explicit TopOrderQueue(const std::vector<StateId> &order)
315       : QueueBase<StateId>(TOP_ORDER_QUEUE),
316         front_(0),
317         back_(kNoStateId),
318         order_(order),
319         state_(order.size(), kNoStateId) {}
320
321   virtual ~TopOrderQueue() = default;
322
323   StateId Head() const override { return state_[front_]; }
324
325   void Enqueue(StateId s) override {
326     if (front_ > back_) {
327       front_ = back_ = order_[s];
328     } else if (order_[s] > back_) {
329       back_ = order_[s];
330     } else if (order_[s] < front_) {
331       front_ = order_[s];
332     }
333     state_[order_[s]] = s;
334   }
335
336   void Dequeue() override {
337     state_[front_] = kNoStateId;
338     while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
339   }
340
341   void Update(StateId) override {}
342
343   bool Empty() const override { return front_ > back_; }
344
345   void Clear() override {
346     for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
347     back_ = kNoStateId;
348     front_ = 0;
349   }
350
351  private:
352   StateId front_;
353   StateId back_;
354   std::vector<StateId> order_;
355   std::vector<StateId> state_;
356 };
357
358 // State order queue discipline, templated on the StateId. States are ordered in
359 // the queue by state ID.
360 template <class S>
361 class StateOrderQueue final : public QueueBase<S> {
362  public:
363   using StateId = S;
364
365   StateOrderQueue()
366       : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
367
368   virtual ~StateOrderQueue() = default;
369
370   StateId Head() const override { return front_; }
371
372   void Enqueue(StateId s) override {
373     if (front_ > back_) {
374       front_ = back_ = s;
375     } else if (s > back_) {
376       back_ = s;
377     } else if (s < front_) {
378       front_ = s;
379     }
380     while (enqueued_.size() <= s) enqueued_.push_back(false);
381     enqueued_[s] = true;
382   }
383
384   void Dequeue() override {
385     enqueued_[front_] = false;
386     while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
387   }
388
389   void Update(StateId) override {}
390
391   bool Empty() const override { return front_ > back_; }
392
393   void Clear() override {
394     for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
395     front_ = 0;
396     back_ = kNoStateId;
397   }
398
399  private:
400   StateId front_;
401   StateId back_;
402   std::vector<bool> enqueued_;
403 };
404
405 // SCC topological-order meta-queue discipline, templated on the StateId and a
406 // queue used inside each SCC. It visits the SCCs of an FST in topological
407 // order. Its constructor is passed the queues to to use within an SCC.
408 template <class S, class Queue>
409 class SccQueue final : public QueueBase<S> {
410  public:
411   using StateId = S;
412
413   // Constructor takes a vector specifying the SCC number per state and a
414   // vector giving the queue to use per SCC number.
415   SccQueue(const std::vector<StateId> &scc,
416            std::vector<std::unique_ptr<Queue>> *queue)
417       : QueueBase<StateId>(SCC_QUEUE),
418         queue_(queue),
419         scc_(scc),
420         front_(0),
421         back_(kNoStateId) {}
422
423   virtual ~SccQueue() = default;
424
425   StateId Head() const override {
426     while ((front_ <= back_) &&
427            (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
428             (((*queue_)[front_] == nullptr) &&
429              ((front_ >= trivial_queue_.size()) ||
430               (trivial_queue_[front_] == kNoStateId))))) {
431       ++front_;
432     }
433     if ((*queue_)[front_]) {
434       return (*queue_)[front_]->Head();
435     } else {
436       return trivial_queue_[front_];
437     }
438   }
439
440   void Enqueue(StateId s) override {
441     if (front_ > back_) {
442       front_ = back_ = scc_[s];
443     } else if (scc_[s] > back_) {
444       back_ = scc_[s];
445     } else if (scc_[s] < front_) {
446       front_ = scc_[s];
447     }
448     if ((*queue_)[scc_[s]]) {
449       (*queue_)[scc_[s]]->Enqueue(s);
450     } else {
451       while (trivial_queue_.size() <= scc_[s]) {
452         trivial_queue_.push_back(kNoStateId);
453       }
454       trivial_queue_[scc_[s]] = s;
455     }
456   }
457
458   void Dequeue() override {
459     if ((*queue_)[front_]) {
460       (*queue_)[front_]->Dequeue();
461     } else if (front_ < trivial_queue_.size()) {
462       trivial_queue_[front_] = kNoStateId;
463     }
464   }
465
466   void Update(StateId s) override {
467     if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
468   }
469
470   bool Empty() const override {
471     // Queue SCC number back_ is not empty unless back_ == front_.
472     if (front_ < back_) {
473       return false;
474     } else if (front_ > back_) {
475       return true;
476     } else if ((*queue_)[front_]) {
477       return (*queue_)[front_]->Empty();
478     } else {
479       return (front_ >= trivial_queue_.size()) ||
480              (trivial_queue_[front_] == kNoStateId);
481     }
482   }
483
484   void Clear() override {
485     for (StateId i = front_; i <= back_; ++i) {
486       if ((*queue_)[i]) {
487         (*queue_)[i]->Clear();
488       } else if (i < trivial_queue_.size()) {
489         trivial_queue_[i] = kNoStateId;
490       }
491     }
492     front_ = 0;
493     back_ = kNoStateId;
494   }
495
496  private:
497   std::vector<std::unique_ptr<Queue>> *queue_;
498   const std::vector<StateId> &scc_;
499   mutable StateId front_;
500   StateId back_;
501   std::vector<StateId> trivial_queue_;
502 };
503
504 // Automatic queue discipline. It selects a queue discipline for a given FST
505 // based on its properties.
506 template <class S>
507 class AutoQueue final : public QueueBase<S> {
508  public:
509   using StateId = S;
510
511   // This constructor takes a state distance vector that, if non-null and if
512   // the Weight type has the path property, will entertain the shortest-first
513   // queue using the natural order w.r.t to the distance.
514   template <class Arc, class ArcFilter>
515   AutoQueue(const Fst<Arc> &fst,
516             const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
517       : QueueBase<StateId>(AUTO_QUEUE) {
518     using Weight = typename Arc::Weight;
519     // TrivialLess is never instantiated since the construction of Less is
520     // guarded by Properties() & kPath.  It is only here to avoid instantiating
521     // NaturalLess for non-path weights.
522     struct TrivialLess {
523       using Weight = typename Arc::Weight;
524       bool operator()(const Weight &, const Weight &) const { return false; }
525     };
526     using Less =
527         typename std::conditional<(Weight::Properties() & kPath) == kPath,
528                                   NaturalLess<Weight>, TrivialLess>::type;
529     using Compare = internal::StateWeightCompare<StateId, Less>;
530     // First checks if the FST is known to have these properties.
531     const auto props =
532         fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false);
533     if ((props & kTopSorted) || fst.Start() == kNoStateId) {
534       queue_.reset(new StateOrderQueue<StateId>());
535       VLOG(2) << "AutoQueue: using state-order discipline";
536     } else if (props & kAcyclic) {
537       queue_.reset(new TopOrderQueue<StateId>(fst, filter));
538       VLOG(2) << "AutoQueue: using top-order discipline";
539     } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) {
540       queue_.reset(new LifoQueue<StateId>());
541       VLOG(2) << "AutoQueue: using LIFO discipline";
542     } else {
543       uint64 properties;
544       // Decomposes into strongly-connected components.
545       SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
546       DfsVisit(fst, &scc_visitor, filter);
547       auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
548       std::vector<QueueType> queue_types(nscc);
549       std::unique_ptr<Less> less;
550       std::unique_ptr<Compare> comp;
551       if (distance && (Weight::Properties() & kPath)) {
552         less.reset(new Less);
553         comp.reset(new Compare(*distance, *less));
554       }
555       // Finds the queue type to use per SCC.
556       bool unweighted;
557       bool all_trivial;
558       SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
559                    &unweighted);
560       // If unweighted and semiring is idempotent, uses LIFO queue.
561       if (unweighted) {
562         queue_.reset(new LifoQueue<StateId>());
563         VLOG(2) << "AutoQueue: using LIFO discipline";
564         return;
565       }
566       // If all the SCC are trivial, the FST is acyclic and the scc number gives
567       // the topological order.
568       if (all_trivial) {
569         queue_.reset(new TopOrderQueue<StateId>(scc_));
570         VLOG(2) << "AutoQueue: using top-order discipline";
571         return;
572       }
573       VLOG(2) << "AutoQueue: using SCC meta-discipline";
574       queues_.resize(nscc);
575       for (StateId i = 0; i < nscc; ++i) {
576         switch (queue_types[i]) {
577           case TRIVIAL_QUEUE:
578             queues_[i].reset();
579             VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
580             break;
581           case SHORTEST_FIRST_QUEUE:
582             queues_[i].reset(
583                 new ShortestFirstQueue<StateId, Compare, false>(*comp));
584             VLOG(3) << "AutoQueue: SCC #" << i
585                     << ": using shortest-first discipline";
586             break;
587           case LIFO_QUEUE:
588             queues_[i].reset(new LifoQueue<StateId>());
589             VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
590             break;
591           case FIFO_QUEUE:
592           default:
593             queues_[i].reset(new FifoQueue<StateId>());
594             VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
595             break;
596         }
597       }
598       queue_.reset(new SccQueue<StateId, QueueBase<StateId>>(scc_, &queues_));
599     }
600   }
601
602   virtual ~AutoQueue() = default;
603
604   StateId Head() const override { return queue_->Head(); }
605
606   void Enqueue(StateId s) override { queue_->Enqueue(s); }
607
608   void Dequeue() override { queue_->Dequeue(); }
609
610   void Update(StateId s) override { queue_->Update(s); }
611
612   bool Empty() const override { return queue_->Empty(); }
613
614   void Clear() override { queue_->Clear(); }
615
616  private:
617   template <class Arc, class ArcFilter, class Less>
618   static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
619                            std::vector<QueueType> *queue_types,
620                            ArcFilter filter, Less *less, bool *all_trivial,
621                            bool *unweighted);
622
623   std::unique_ptr<QueueBase<StateId>> queue_;
624   std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
625   std::vector<StateId> scc_;
626 };
627
628 // Examines the states in an FST's strongly connected components and determines
629 // which type of queue to use per SCC. Stores result as a vector of QueueTypes
630 // which is assumed to have length equal to the number of SCCs. An arc filter
631 // is used to limit the transitions considered (e.g., only the epsilon graph).
632 // The argument all_trivial is set to true if every queue is the trivial queue.
633 // The argument unweighted is set to true if the semiring is idempotent and all
634 // the arc weights are equal to Zero() or One().
635 template <class StateId>
636 template <class Arc, class ArcFilter, class Less>
637 void AutoQueue<StateId>::SccQueueType(const Fst<Arc> &fst,
638                                       const std::vector<StateId> &scc,
639                                       std::vector<QueueType> *queue_type,
640                                       ArcFilter filter, Less *less,
641                                       bool *all_trivial, bool *unweighted) {
642   using StateId = typename Arc::StateId;
643   using Weight = typename Arc::Weight;
644   *all_trivial = true;
645   *unweighted = true;
646   for (StateId i = 0; i < queue_type->size(); ++i) {
647     (*queue_type)[i] = TRIVIAL_QUEUE;
648   }
649   for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
650     const auto state = sit.Value();
651     for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
652       const auto &arc = ait.Value();
653       if (!filter(arc)) continue;
654       if (scc[state] == scc[arc.nextstate]) {
655         auto &type = (*queue_type)[scc[state]];
656         if (!less || ((*less)(arc.weight, Weight::One()))) {
657           type = FIFO_QUEUE;
658         } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
659           if (!(Weight::Properties() & kIdempotent) ||
660               (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
661             type = SHORTEST_FIRST_QUEUE;
662           } else {
663             type = LIFO_QUEUE;
664           }
665         }
666         if (type != TRIVIAL_QUEUE) *all_trivial = false;
667       }
668       if (!(Weight::Properties() & kIdempotent) ||
669           (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
670         *unweighted = false;
671       }
672     }
673   }
674 }
675
676 // An A* estimate is a function object that maps from a state ID to a an
677 // estimate of the shortest distance to the final states.
678
679 // The trivial A* estimate is always One().
680 template <typename StateId, typename Weight>
681 struct TrivialAStarEstimate {
682   Weight operator()(StateId) const { return Weight::One(); }
683 };
684
685 // Given a vector that maps from states to weights representing the shortest
686 // distance from the initial state, a comparison function object between
687 // weights, and an estimate of the shortest distance to the final states, this
688 // class defines a comparison function object between states.
689 template <typename S, typename Less, typename Estimate>
690 class AStarWeightCompare {
691  public:
692   using StateId = S;
693   using Weight = typename Less::Weight;
694
695   AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
696                      const Estimate &estimate)
697       : weights_(weights), less_(less), estimate_(estimate) {}
698
699   bool operator()(const StateId s1, const StateId s2) const {
700     const auto w1 = Times(weights_[s1], estimate_(s1));
701     const auto w2 = Times(weights_[s2], estimate_(s2));
702     return less_(w1, w2);
703   }
704
705  private:
706   // Borrowed references.
707   const std::vector<Weight> &weights_;
708   const Less &less_;
709   const Estimate &estimate_;
710 };
711
712 // A* queue discipline templated on StateId, Weight, and Estimate.
713 template <typename S, typename Weight, typename Estimate>
714 class NaturalAStarQueue final
715     : public ShortestFirstQueue<
716           S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
717  public:
718   using StateId = S;
719   using Compare = AStarWeightCompare<StateId, NaturalLess<Weight>, Estimate>;
720
721   NaturalAStarQueue(const std::vector<Weight> &distance,
722                     const Estimate &estimate)
723       : ShortestFirstQueue<StateId, Compare>(
724             Compare(distance, less_, estimate)) {}
725
726   virtual ~NaturalAStarQueue() = default;
727
728  private:
729   // This is non-static because the constructor for non-idempotent weights will
730   // result in a an error.
731   const NaturalLess<Weight> less_{};
732 };
733
734 // A state equivalence class is a function object that maps from a state ID to
735 // an equivalence class (state) ID. The trivial equivalence class maps a state
736 // ID to itself.
737 template <typename StateId>
738 struct TrivialStateEquivClass {
739   StateId operator()(StateId s) const { return s; }
740 };
741
742 // Distance-based pruning queue discipline: Enqueues a state only when its
743 // shortest distance (so far), as specified by distance, is less than (as
744 // specified by comp) the shortest distance Times() the threshold to any state
745 // in the same equivalence class, as specified by the functor class_func. The
746 // underlying queue discipline is specified by queue. The ownership of queue is
747 // given to this class.
748 //
749 // This is not a final class.
750 template <typename Queue, typename Less, typename ClassFnc>
751 class PruneQueue : public QueueBase<typename Queue::StateId> {
752  public:
753   using StateId = typename Queue::StateId;
754   using Weight = typename Less::Weight;
755
756   PruneQueue(const std::vector<Weight> &distance, Queue *queue,
757              const Less &less, const ClassFnc &class_fnc, Weight threshold)
758       : QueueBase<StateId>(OTHER_QUEUE),
759         distance_(distance),
760         queue_(queue),
761         less_(less),
762         class_fnc_(class_fnc),
763         threshold_(std::move(threshold)) {}
764
765   virtual ~PruneQueue() = default;
766
767   StateId Head() const override { return queue_->Head(); }
768
769   void Enqueue(StateId s) override {
770     const auto c = class_fnc_(s);
771     if (c >= class_distance_.size()) {
772       class_distance_.resize(c + 1, Weight::Zero());
773     }
774     if (less_(distance_[s], class_distance_[c])) {
775       class_distance_[c] = distance_[s];
776     }
777     // Enqueues only if below threshold limit.
778     const auto limit = Times(class_distance_[c], threshold_);
779     if (less_(distance_[s], limit)) queue_->Enqueue(s);
780   }
781
782   void Dequeue() override { queue_->Dequeue(); }
783
784   void Update(StateId s) override {
785     const auto c = class_fnc_(s);
786     if (less_(distance_[s], class_distance_[c])) {
787       class_distance_[c] = distance_[s];
788     }
789     queue_->Update(s);
790   }
791
792   bool Empty() const override { return queue_->Empty(); }
793
794   void Clear() override { queue_->Clear(); }
795
796  private:
797   const std::vector<Weight> &distance_;  // Shortest distance to state.
798   std::unique_ptr<Queue> queue_;
799   const Less &less_;                    // Borrowed reference.
800   const ClassFnc &class_fnc_;           // Equivalence class functor.
801   Weight threshold_;                    // Pruning weight threshold.
802   std::vector<Weight> class_distance_;  // Shortest distance to class.
803 };
804
805 // Pruning queue discipline (see above) using the weight's natural order for the
806 // comparison function. The ownership of the queue argument is given to this
807 // class.
808 template <typename Queue, typename Weight, typename ClassFnc>
809 class NaturalPruneQueue final
810     : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
811  public:
812   using StateId = typename Queue::StateId;
813
814   NaturalPruneQueue(const std::vector<Weight> &distance, Queue *queue,
815                     const ClassFnc &class_fnc, Weight threshold)
816       : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
817             distance, queue, NaturalLess<Weight>(), class_fnc, threshold) {}
818
819   virtual ~NaturalPruneQueue() = default;
820 };
821
822 // Filter-based pruning queue discipline: enqueues a state only if allowed by
823 // the filter, specified by the state filter functor argument. The underlying
824 // queue discipline is specified by the queue argument. The ownership of the
825 // queue is given to this class.
826 template <typename Queue, typename Filter>
827 class FilterQueue final : public QueueBase<typename Queue::StateId> {
828  public:
829   using StateId = typename Queue::StateId;
830
831   FilterQueue(Queue *queue, const Filter &filter)
832       : QueueBase<StateId>(OTHER_QUEUE), queue_(queue), filter_(filter) {}
833
834   virtual ~FilterQueue() = default;
835
836   StateId Head() const override { return queue_->Head(); }
837
838   // Enqueues only if allowed by state filter.
839   void Enqueue(StateId s) override {
840     if (filter_(s)) queue_->Enqueue(s);
841   }
842
843   void Dequeue() override { queue_->Dequeue(); }
844
845   void Update(StateId s) override {}
846
847   bool Empty() const override { return queue_->Empty(); }
848
849   void Clear() override { queue_->Clear(); }
850
851  private:
852   std::unique_ptr<Queue> queue_;
853   const Filter &filter_;
854 };
855
856 }  // namespace fst
857
858 #endif  // FST_QUEUE_H_