Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_sum.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef REF_SUM_HPP
18 #define REF_SUM_HPP
19
20 #include "cpu_sum.hpp"
21 #include "reorder_pd.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 struct ref_sum_t: public cpu_primitive_t {
28     using cpu_memory_pd_t = cpu_memory_t::pd_t;
29
30     struct pd_t: public cpu_sum_pd_t {
31         pd_t(const memory_desc_t *output_d, int n, const float *scales,
32                 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
33             : cpu_sum_pd_t(output_d, n, scales, input_pds, attr) {}
34         pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) {
35             for (size_t i = 0; i < rhs.scales_.size(); ++i) {
36                 scales_.push_back(rhs.scales_[i]);
37             }
38             for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) {
39                 reorder_pds_.push_back(
40                        (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
41             }
42         }
43
44         ~pd_t() {
45             for (size_t i = 0; i < reorder_pds_.size(); ++i) {
46                 delete reorder_pds_[i];
47             }
48         }
49
50         static status_t create(sum_pd_t **sum_pd,
51                 const memory_desc_t *output_d, int n, const float *scales,
52                 const memory_pd_t **input_pds, const primitive_attr_t *attr) {
53             auto _pd = new pd_t(output_d, n, scales,
54                     (const cpu_memory_pd_t **)input_pds, attr);
55             if (_pd == nullptr) return out_of_memory;
56             if (_pd->init() != success) { delete _pd; return unimplemented; }
57             return safe_ptr_assign<sum_pd_t>(*sum_pd, _pd);
58         }
59
60         virtual status_t create_primitive(primitive_t **primitive,
61                 const primitive_at_t *inputs, const primitive_t **outputs)
62                 const override {
63             double ms = get_msec();
64             nstl::vector<primitive_t *> reorders;
65             reorders.resize(n_);
66             for (int i = 0; i < n_; ++i)
67                 CHECK(reorder_pds_[i]->create_primitive(&reorders[i],
68                             &inputs[i], outputs));
69
70             primitive_t::input_vector ins(inputs, inputs + n_);
71             primitive_t::output_vector outs(outputs, outputs + 1);
72             auto ret = safe_ptr_assign<primitive_t>(*primitive,
73                      new ref_sum_t(this, ins, outs, reorders));
74             ms = get_msec() - ms;
75             if (mkldnn_verbose()->level >= 2) {
76                 printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms);
77                 fflush(0);
78             }
79             return ret;
80         }
81         virtual pd_t *clone() const override { return new pd_t(*this); }
82         virtual const char *name() const override { return "ref:any"; }
83
84         virtual status_t init() override {
85             bool ok = cpu_sum_pd_t::init() == success;
86             if (!ok) return unimplemented;
87
88             for (int i = 0; i < n_; ++i) {
89                 auto r_impls = engine_->get_reorder_implementation_list();
90                 for (auto r = r_impls; *r; ++r) {
91                     primitive_attr_t dummy_attr;
92                     dummy_attr.output_scales_.set(scales_[i]);
93                     reorder_pd_t *r_pd;
94                     if (i != 0) {
95                         dummy_attr.post_ops_.append_sum(1.0);
96                     }
97                     if ((*r)(&r_pd, &src_pds_[i], &dst_pd_, &dummy_attr)
98                             == status::success) {
99                         r_pd->init_info();
100                         reorder_pds_.push_back(r_pd);
101                         break;
102                     }
103                 }
104             }
105             ok = utils::everyone_is(reorder_pds_.size(), scales_.size());
106             return ok ? success : unimplemented;
107         }
108
109         nstl::vector<const reorder_pd_t *> reorder_pds_;
110     };
111
112     ref_sum_t(const pd_t *apd, const input_vector &inputs,
113             const output_vector &outputs, nstl::vector<primitive_t *> reorders)
114         : cpu_primitive_t(apd, inputs, outputs),
115         reorders_(reorders) {}
116
117     ~ref_sum_t() {
118         const auto n = reorders_.size();
119         for (size_t i = 0; i < n; ++i)
120             delete reorders_[i];
121     }
122
123     virtual void execute(event_t *e) const {
124         const auto n = reorders_.size();
125         for (size_t i = 0; i < n; ++i) {
126             event_t ei;
127             reorders_[i]->execute(&ei);
128         }
129         e->set_state(event_t::ready);
130     }
131
132 private:
133     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
134     nstl::vector<primitive_t *> reorders_;
135 };
136
137 }
138 }
139 }
140
141 #endif