Imported Upstream version 1.72.0
[platform/upstream/boost.git] / boost / histogram / accumulators / weighted_mean.hpp
1 // Copyright 2018 Hans Dembinski
2 //
3 // Distributed under the Boost Software License, version 1.0.
4 // (See accompanying file LICENSE_1_0.txt
5 // or copy at http://www.boost.org/LICENSE_1_0.txt)
6
7 #ifndef BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP
8 #define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP
9
10 #include <boost/assert.hpp>
11 #include <boost/core/nvp.hpp>
12 #include <boost/histogram/fwd.hpp> // for weighted_mean<>
13 #include <boost/histogram/weight.hpp>
14 #include <type_traits>
15
16 namespace boost {
17 namespace histogram {
18 namespace accumulators {
19
20 /**
21   Calculates mean and variance of weighted sample.
22
23   Uses West's incremental algorithm to improve numerical stability
24   of mean and variance computation.
25 */
26 template <typename RealType>
27 class weighted_mean {
28 public:
29   weighted_mean() = default;
30   weighted_mean(const RealType& wsum, const RealType& wsum2, const RealType& mean,
31                 const RealType& variance)
32       : sum_of_weights_(wsum)
33       , sum_of_weights_squared_(wsum2)
34       , weighted_mean_(mean)
35       , sum_of_weighted_deltas_squared_(
36             variance * (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_)) {}
37
38   void operator()(const RealType& x) { operator()(weight(1), x); }
39
40   void operator()(const weight_type<RealType>& w, const RealType& x) {
41     sum_of_weights_ += w.value;
42     sum_of_weights_squared_ += w.value * w.value;
43     const auto delta = x - weighted_mean_;
44     weighted_mean_ += w.value * delta / sum_of_weights_;
45     sum_of_weighted_deltas_squared_ += w.value * delta * (x - weighted_mean_);
46   }
47
48   template <typename T>
49   weighted_mean& operator+=(const weighted_mean<T>& rhs) {
50     if (sum_of_weights_ != 0 || rhs.sum_of_weights_ != 0) {
51       const auto tmp = weighted_mean_ * sum_of_weights_ +
52                        static_cast<RealType>(rhs.weighted_mean_ * rhs.sum_of_weights_);
53       sum_of_weights_ += static_cast<RealType>(rhs.sum_of_weights_);
54       sum_of_weights_squared_ += static_cast<RealType>(rhs.sum_of_weights_squared_);
55       weighted_mean_ = tmp / sum_of_weights_;
56     }
57     sum_of_weighted_deltas_squared_ +=
58         static_cast<RealType>(rhs.sum_of_weighted_deltas_squared_);
59     return *this;
60   }
61
62   weighted_mean& operator*=(const RealType& s) {
63     weighted_mean_ *= s;
64     sum_of_weighted_deltas_squared_ *= s * s;
65     return *this;
66   }
67
68   template <typename T>
69   bool operator==(const weighted_mean<T>& rhs) const noexcept {
70     return sum_of_weights_ == rhs.sum_of_weights_ &&
71            sum_of_weights_squared_ == rhs.sum_of_weights_squared_ &&
72            weighted_mean_ == rhs.weighted_mean_ &&
73            sum_of_weighted_deltas_squared_ == rhs.sum_of_weighted_deltas_squared_;
74   }
75
76   template <typename T>
77   bool operator!=(const T& rhs) const noexcept {
78     return !operator==(rhs);
79   }
80
81   const RealType& sum_of_weights() const noexcept { return sum_of_weights_; }
82   const RealType& sum_of_weights_squared() const noexcept {
83     return sum_of_weights_squared_;
84   }
85   const RealType& value() const noexcept { return weighted_mean_; }
86   RealType variance() const {
87     return sum_of_weighted_deltas_squared_ /
88            (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_);
89   }
90
91   template <class Archive>
92   void serialize(Archive& ar, unsigned /* version */) {
93     ar& make_nvp("sum_of_weights", sum_of_weights_);
94     ar& make_nvp("sum_of_weights_squared", sum_of_weights_squared_);
95     ar& make_nvp("weighted_mean", weighted_mean_);
96     ar& make_nvp("sum_of_weighted_deltas_squared", sum_of_weighted_deltas_squared_);
97   }
98
99 private:
100   RealType sum_of_weights_ = RealType(), sum_of_weights_squared_ = RealType(),
101            weighted_mean_ = RealType(), sum_of_weighted_deltas_squared_ = RealType();
102 };
103
104 } // namespace accumulators
105 } // namespace histogram
106 } // namespace boost
107
108 #ifndef BOOST_HISTOGRAM_DOXYGEN_INVOKED
109 namespace std {
110 template <class T, class U>
111 /// Specialization for boost::histogram::accumulators::weighted_mean.
112 struct common_type<boost::histogram::accumulators::weighted_mean<T>,
113                    boost::histogram::accumulators::weighted_mean<U>> {
114   using type = boost::histogram::accumulators::weighted_mean<common_type_t<T, U>>;
115 };
116 } // namespace std
117 #endif
118
119 #endif