Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_sum.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef SIMPLE_SUM_HPP
18 #define SIMPLE_SUM_HPP
19
20 #include "cpu_sum.hpp"
21
22 namespace mkldnn {
23 namespace impl {
24 namespace cpu {
25
26 template <data_type_t data_type>
27 struct simple_sum_t: public cpu_primitive_t {
28     using cpu_memory_pd_t = cpu_memory_t::pd_t;
29
30     struct pd_t: public cpu_sum_pd_t {
31         pd_t(const memory_desc_t *output_d, int n, const float *scales,
32                 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
33             : cpu_sum_pd_t(output_d, n, scales, input_pds, attr) {}
34
35         DECLARE_CPU_SUM_PD_T("simple:any", simple_sum_t);
36
37         virtual status_t init() override {
38             bool ok = true
39                 && cpu_sum_pd_t::init() == success
40                 && src_pds_.size() <= max_num_arrs;
41             if (!ok) return unimplemented;
42
43             const memory_desc_wrapper o_d(&dst_pd_);
44             ok = ok
45                 && o_d.data_type() == data_type
46                 && o_d.is_dense();
47
48             const auto n = src_pds_.size();
49             for (size_t i = 0; i < n; ++i) {
50                 const memory_desc_wrapper i_d(&src_pds_[i]);
51                 ok = ok
52                     && utils::everyone_is(data_type, i_d.data_type())
53                     && i_d.format() == o_d.format()
54                     && i_d.is_dense();
55             }
56
57             return ok ? success : unimplemented;
58         }
59     };
60
61     simple_sum_t(const pd_t *apd, const input_vector &inputs,
62             const output_vector &outputs)
63         : cpu_primitive_t(apd, inputs, outputs) {}
64
65     virtual void execute(event_t *e) const {
66         execute();
67         e->set_state(event_t::ready);
68     }
69
70     enum {max_num_arrs = 16 };
71     typedef typename prec_traits<data_type>::type data_t;
72
73 private:
74     void execute() const;
75     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
76 };
77
78 }
79 }
80 }
81
82 #endif
83
84 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s