1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Represents a generic weight in an FST; that is, represents a specific type
5 // of weight underneath while hiding that type from a client.
7 #ifndef FST_SCRIPT_WEIGHT_CLASS_H_
8 #define FST_SCRIPT_WEIGHT_CLASS_H_
15 #include <fst/generic-register.h>
17 #include <fst/weight.h>
// Abstract, type-erased interface to a single weight value. WeightClass owns a
// pointer to this base and forwards every operation through it; the concrete
// weight type is known only to the WeightClassImpl<W> subclass.
//
// NOTE: the order of the virtual declarations below is deliberately left
// unchanged, since it determines the vtable layout.
class WeightImplBase {
 public:
  // Returns a newly allocated copy of this object; the caller owns it.
  virtual WeightImplBase *Copy() const = 0;

  // Writes the weight to the stream pointed to by o.
  virtual void Print(std::ostream *o) const = 0;

  // Name of the underlying weight type.
  virtual const string &Type() const = 0;

  // Textual form of the weight.
  virtual string ToString() const = 0;

  // Equality against another type-erased weight. The WeightClassImpl
  // implementation static_casts `other` without checking, so both operands
  // are expected to wrap the same weight type.
  virtual bool operator==(const WeightImplBase &other) const = 0;
  virtual bool operator!=(const WeightImplBase &other) const = 0;

  // In-place semiring operations. As with operator==, `other` is expected to
  // wrap the same weight type as this object.
  virtual WeightImplBase &PlusEq(const WeightImplBase &other) = 0;
  virtual WeightImplBase &TimesEq(const WeightImplBase &other) = 0;
  virtual WeightImplBase &DivideEq(const WeightImplBase &other) = 0;
  virtual WeightImplBase &PowerEq(size_t n) = 0;

  // Virtual so that WeightClass can delete implementations through a base
  // pointer.
  virtual ~WeightImplBase() = default;
};
38 class WeightClassImpl : public WeightImplBase {
40 explicit WeightClassImpl(const W &weight) : weight_(weight) {}
42 WeightClassImpl<W> *Copy() const final {
43 return new WeightClassImpl<W>(weight_);
46 const string &Type() const final { return W::Type(); }
48 void Print(std::ostream *ostrm) const final { *ostrm << weight_; }
50 string ToString() const final {
52 WeightToStr(weight_, &str);
56 bool operator==(const WeightImplBase &other) const final {
57 const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
58 return weight_ == typed_other->weight_;
61 bool operator!=(const WeightImplBase &other) const final {
62 return !(*this == other);
65 WeightClassImpl<W> &PlusEq(const WeightImplBase &other) final {
66 const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
67 weight_ = Plus(weight_, typed_other->weight_);
71 WeightClassImpl<W> &TimesEq(const WeightImplBase &other) final {
72 const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
73 weight_ = Times(weight_, typed_other->weight_);
77 WeightClassImpl<W> &DivideEq(const WeightImplBase &other) final {
78 const auto *typed_other = static_cast<const WeightClassImpl<W> *>(&other);
79 weight_ = Divide(weight_, typed_other->weight_);
83 WeightClassImpl<W> &PowerEq(size_t n) final {
84 weight_ = Power(weight_, n);
88 W *GetImpl() { return &weight_; }
97 WeightClass() = default;
100 explicit WeightClass(const W &weight)
101 : impl_(new WeightClassImpl<W>(weight)) {}
104 explicit WeightClass(const WeightClassImpl<W> &impl)
105 : impl_(new WeightClassImpl<W>(impl)) {}
107 WeightClass(const string &weight_type, const string &weight_str);
109 WeightClass(const WeightClass &other)
110 : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {}
112 WeightClass &operator=(const WeightClass &other) {
113 impl_.reset(other.impl_ ? other.impl_->Copy() : nullptr);
117 static constexpr const char *__ZERO__ = "__ZERO__"; // NOLINT
119 static WeightClass Zero(const string &weight_type);
121 static constexpr const char *__ONE__ = "__ONE__"; // NOLINT
123 static WeightClass One(const string &weight_type);
125 static constexpr const char *__NOWEIGHT__ = "__NOWEIGHT__"; // NOLINT
127 static WeightClass NoWeight(const string &weight_type);
130 const W *GetWeight() const {
131 if (W::Type() != impl_->Type()) {
134 auto *typed_impl = static_cast<WeightClassImpl<W> *>(impl_.get());
135 return typed_impl->GetImpl();
139 string ToString() const { return (impl_) ? impl_->ToString() : "none"; }
141 const string &Type() const {
142 if (impl_) return impl_->Type();
143 static const string *const no_type = new string("none");
147 bool WeightTypesMatch(const WeightClass &other, const string &op_name) const;
149 friend bool operator==(const WeightClass &lhs, const WeightClass &rhs);
151 friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
153 friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
155 friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
157 friend WeightClass Power(const WeightClass &w, size_t n);
160 const WeightImplBase *GetImpl() const { return impl_.get(); }
162 WeightImplBase *GetImpl() { return impl_.get(); }
164 std::unique_ptr<WeightImplBase> impl_;
166 friend std::ostream &operator<<(std::ostream &o, const WeightClass &c);
169 bool operator==(const WeightClass &lhs, const WeightClass &rhs);
171 bool operator!=(const WeightClass &lhs, const WeightClass &rhs);
173 WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
175 WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
177 WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
179 WeightClass Power(const WeightClass &w, size_t n);
181 std::ostream &operator<<(std::ostream &o, const WeightClass &c);
183 // Registration for generic weight types.
185 using StrToWeightImplBaseT = WeightImplBase *(*)(const string &str,
190 WeightImplBase *StrToWeightImplBase(const string &str, const string &src,
192 if (str == WeightClass::__ZERO__)
193 return new WeightClassImpl<W>(W::Zero());
194 else if (str == WeightClass::__ONE__)
195 return new WeightClassImpl<W>(W::One());
196 else if (str == WeightClass::__NOWEIGHT__)
197 return new WeightClassImpl<W>(W::NoWeight());
198 return new WeightClassImpl<W>(StrToWeight<W>(str, src, nline));
201 class WeightClassRegister : public GenericRegister<string, StrToWeightImplBaseT,
202 WeightClassRegister> {
204 string ConvertKeyToSoFilename(const string &key) const final {
205 string legal_type(key);
206 ConvertToLegalCSymbol(&legal_type);
207 return legal_type + ".so";
211 using WeightClassRegisterer = GenericRegisterer<WeightClassRegister>;
// Internal version; it must be invoked through REGISTER_FST_WEIGHT_EXPANDER.
// Operands of ## are not macro-expanded, so without that extra layer the
// variable name would be built from the literal token __LINE__ instead of the
// line number. Because the name embeds only the line number, two
// registrations on the same line of one translation unit would collide.
#define REGISTER_FST_WEIGHT__(Weight, line)                \
  static WeightClassRegisterer weight_registerer##_##line( \
      Weight::Type(), StrToWeightImplBase<Weight>)

// This layer is where __LINE__ is expanded (into a number) before it is
// passed on to REGISTER_FST_WEIGHT__.
#define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \
  REGISTER_FST_WEIGHT__(Weight, line)

// Macro for registering new weight types. Clients call this at namespace scope,
// e.g. `REGISTER_FST_WEIGHT(MyWeight);`. It registers MyWeight::Type() with
// the factory StrToWeightImplBase<MyWeight>.
#define REGISTER_FST_WEIGHT(Weight) \
  REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__)
227 } // namespace script
230 #endif // FST_SCRIPT_WEIGHT_CLASS_H_