Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / script / weight-class.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Represents a generic weight in an FST; that is, represents a specific type
5 // of weight underneath while hiding that type from a client.
6
7 #ifndef FST_SCRIPT_WEIGHT_CLASS_H_
8 #define FST_SCRIPT_WEIGHT_CLASS_H_
9
10 #include <memory>
11 #include <ostream>
12 #include <string>
13
14 #include <fst/arc.h>
15 #include <fst/generic-register.h>
16 #include <fst/util.h>
17 #include <fst/weight.h>
18
19 namespace fst {
20 namespace script {
21
22 class WeightImplBase {
23  public:
24   virtual WeightImplBase *Copy() const = 0;
25   virtual void Print(std::ostream *o) const = 0;
26   virtual const string &Type() const = 0;
27   virtual string ToString() const = 0;
28   virtual bool operator==(const WeightImplBase &other) const = 0;
29   virtual bool operator!=(const WeightImplBase &other) const = 0;
30   virtual WeightImplBase &PlusEq(const WeightImplBase &other) = 0;
31   virtual WeightImplBase &TimesEq(const WeightImplBase &other) = 0;
32   virtual WeightImplBase &DivideEq(const WeightImplBase &other) = 0;
33   virtual WeightImplBase &PowerEq(size_t n) = 0;
34   virtual ~WeightImplBase() {}
35 };
36
37 template <class W>
38 class WeightClassImpl : public WeightImplBase {
39  public:
40   explicit WeightClassImpl(const W &weight) : weight_(weight) {}
41
42   WeightClassImpl<W> *Copy() const final {
43     return new WeightClassImpl<W>(weight_);
44   }
45
46   const string &Type() const final { return W::Type(); }
47
48   void Print(std::ostream *ostrm) const final { *ostrm << weight_; }
49
50   string ToString() const final {
51     string str;
52     WeightToStr(weight_, &str);
53     return str;
54   }
55
56   bool operator==(const WeightImplBase &other) const final {
57     const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
58     return weight_ == typed_other->weight_;
59   }
60
61   bool operator!=(const WeightImplBase &other) const final {
62     return !(*this == other);
63   }
64
65   WeightClassImpl<W> &PlusEq(const WeightImplBase &other) final {
66     const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
67     weight_ = Plus(weight_, typed_other->weight_);
68     return *this;
69   }
70
71   WeightClassImpl<W> &TimesEq(const WeightImplBase &other) final {
72     const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
73     weight_ = Times(weight_, typed_other->weight_);
74     return *this;
75   }
76
77   WeightClassImpl<W> &DivideEq(const WeightImplBase &other) final {
78     const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
79     weight_ = Divide(weight_, typed_other->weight_);
80     return *this;
81   }
82
83   WeightClassImpl<W> &PowerEq(size_t n) final {
84     weight_ = Power(weight_, n);
85     return *this;
86   }
87
88   W *GetImpl() { return &weight_; }
89
90  private:
91   W weight_;
92 };
93
94
95 class WeightClass {
96  public:
97   WeightClass() = default;
98
99   template <class W>
100   explicit WeightClass(const W &weight)
101       : impl_(new WeightClassImpl<W>(weight)) {}
102
103   template <class W>
104   explicit WeightClass(const WeightClassImpl<W> &impl)
105       : impl_(new WeightClassImpl<W>(impl)) {}
106
107   WeightClass(const string &weight_type, const string &weight_str);
108
109   WeightClass(const WeightClass &other)
110       : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {}
111
112   WeightClass &operator=(const WeightClass &other) {
113     impl_.reset(other.impl_ ? other.impl_->Copy() : nullptr);
114     return *this;
115   }
116
117   static constexpr const char *__ZERO__ = "__ZERO__";  // NOLINT
118
119   static WeightClass Zero(const string &weight_type);
120
121   static constexpr const char *__ONE__ = "__ONE__";  // NOLINT
122
123   static WeightClass One(const string &weight_type);
124
125   static constexpr const char *__NOWEIGHT__ = "__NOWEIGHT__";  // NOLINT
126
127   static WeightClass NoWeight(const string &weight_type);
128
129   template <class W>
130   const W *GetWeight() const {
131     if (W::Type() != impl_->Type()) {
132        return nullptr;
133     } else {
134       auto *typed_impl = static_cast<WeightClassImpl<W> *>(impl_.get());
135       return typed_impl->GetImpl();
136     }
137   }
138
139   string ToString() const { return (impl_) ? impl_->ToString() : "none"; }
140
141   const string &Type() const {
142     if (impl_) return impl_->Type();
143     static const string *const no_type = new string("none");
144     return *no_type;
145   }
146
147   bool WeightTypesMatch(const WeightClass &other, const string &op_name) const;
148
149   friend bool operator==(const WeightClass &lhs, const WeightClass &rhs);
150
151   friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
152
153   friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
154
155   friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
156
157   friend WeightClass Power(const WeightClass &w, size_t n);
158
159  private:
160   const WeightImplBase *GetImpl() const { return impl_.get(); }
161
162   WeightImplBase *GetImpl() { return impl_.get(); }
163
164   std::unique_ptr<WeightImplBase> impl_;
165
166   friend std::ostream &operator<<(std::ostream &o, const WeightClass &c);
167 };
168
169 bool operator==(const WeightClass &lhs, const WeightClass &rhs);
170
171 bool operator!=(const WeightClass &lhs, const WeightClass &rhs);
172
173 WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
174
175 WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
176
177 WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
178
179 WeightClass Power(const WeightClass &w, size_t n);
180
181 std::ostream &operator<<(std::ostream &o, const WeightClass &c);
182
183 // Registration for generic weight types.
184
185 using StrToWeightImplBaseT = WeightImplBase *(*)(const string &str,
186                                                  const string &src,
187                                                  size_t nline);
188
189 template <class W>
190 WeightImplBase *StrToWeightImplBase(const string &str, const string &src,
191                                     size_t nline) {
192   if (str == WeightClass::__ZERO__)
193     return new WeightClassImpl<W>(W::Zero());
194   else if (str == WeightClass::__ONE__)
195     return new WeightClassImpl<W>(W::One());
196   else if (str == WeightClass::__NOWEIGHT__)
197     return new WeightClassImpl<W>(W::NoWeight());
198   return new WeightClassImpl<W>(StrToWeight<W>(str, src, nline));
199 }
200
201 class WeightClassRegister : public GenericRegister<string, StrToWeightImplBaseT,
202                                                    WeightClassRegister> {
203  protected:
204   string ConvertKeyToSoFilename(const string &key) const final {
205     string legal_type(key);
206     ConvertToLegalCSymbol(&legal_type);
207     return legal_type + ".so";
208   }
209 };
210
211 using WeightClassRegisterer = GenericRegisterer<WeightClassRegister>;
212
213 // Internal version; needs to be called by wrapper in order for macro args to
214 // expand.
215 #define REGISTER_FST_WEIGHT__(Weight, line)                \
216   static WeightClassRegisterer weight_registerer##_##line( \
217       Weight::Type(), StrToWeightImplBase<Weight>)
218
219 // This layer is where __FILE__ and __LINE__ are expanded.
220 #define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \
221   REGISTER_FST_WEIGHT__(Weight, line)
222
223 // Macro for registering new weight types. Clients call this.
224 #define REGISTER_FST_WEIGHT(Weight) \
225   REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__)
226
227 }  // namespace script
228 }  // namespace fst
229
230 #endif  // FST_SCRIPT_WEIGHT_CLASS_H_