// Commit note: updated readme file due to moving CMake scripts to the root folder
// Source: [platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_sum.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef SIMPLE_SUM_HPP
18 #define SIMPLE_SUM_HPP
19
20 #include "cpu_sum.hpp"
21 #include "cpu_isa_traits.hpp"
22 #include "bfloat16_utils.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
/* Per-thread f32 workspace sizing for the bf16 sum path; filled in by
 * simple_sum_t::pd_t::init_scratchpad().
 *
 * NOTE(review): this struct previously sat in an unnamed namespace. An
 * unnamed namespace in a header gives the type internal linkage, so every
 * translation unit including this header sees a *distinct* type — an ODR
 * hazard once the type is used as a member of the simple_sum_t template.
 * Declaring it at namespace scope keeps a single definition. */
struct sum_bf16_params_t {
    size_t ws_cvt_elements_per_thread_; // f32 elements booked for the bf16->f32 conversion buffer
    size_t ws_acc_elements_per_thread_; // extra f32 elements booked when dst is bf16, else 0
    size_t ws_elements_per_thread_;     // total per-thread f32 elements (cvt + acc)
    size_t acc_loop_step_;              // loop step used by the accumulation pass
};
36
37 template <data_type_t src_data_type, data_type_t dst_data_type>
38 struct simple_sum_t: public cpu_primitive_t {
39     using cpu_memory_pd_t = cpu_memory_t::pd_t;
40
41     struct pd_t: public cpu_sum_pd_t {
42         pd_t(const memory_desc_t *output_d, int n, const float *scales,
43              const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
44             : cpu_sum_pd_t(output_d, n, scales, input_pds, attr) {}
45
46         DECLARE_CPU_SUM_PD_T("simple:any", simple_sum_t);
47
48         virtual status_t init() override {
49             bool ok = true
50                 && cpu_sum_pd_t::init() == success
51                 && src_pds_.size() <= max_num_arrs;
52             if (!ok) return unimplemented;
53
54             const memory_desc_wrapper o_d(&dst_pd_);
55             ok = ok
56                 && o_d.data_type() == dst_data_type
57                 && o_d.is_dense();
58             if (!ok) return unimplemented;
59
60             const auto n = src_pds_.size();
61             for (size_t i = 0; i < n; ++i) {
62                 const memory_desc_wrapper i_d(&src_pds_[i]);
63                 ok = true
64                     && utils::everyone_is(src_data_type, i_d.data_type())
65                     && i_d.format() == o_d.format()
66                     && i_d.is_dense();
67                 if (!ok) return unimplemented;
68             }
69
70             compute_blocking();
71             init_scratchpad();
72
73             return success;
74         }
75
76         sum_bf16_params_t bf16_p_;
77         size_t block_size_, nelems_, blocks_number_, tail_;
78
79         private:
80
81             const size_t cacheline_size_ = 64; // bytes
82             const size_t half_L1_size_ = 16 * 1024; // bytes
83
84             void compute_blocking() {
85                 block_size_ = (src_data_type == data_type::bf16
86                         ?  16 * cacheline_size_
87                         : half_L1_size_)
88                     / sizeof(src_data_type);
89                 nelems_ = memory_desc_wrapper(dst_pd()).nelems();
90                 blocks_number_ = nelems_ / block_size_;
91                 tail_ = nelems_ % block_size_;
92             }
93
94             void init_scratchpad() {
95                 if (src_data_type == data_type::bf16) {
96                     bool is_dst_bf16_ = dst_data_type == data_type::bf16;
97                     bf16_p_.ws_cvt_elements_per_thread_ =
98                         cacheline_size_ / sizeof(acc_data_t);
99
100                     bf16_p_.ws_acc_elements_per_thread_ =
101                         is_dst_bf16_
102                         ? bf16_p_.ws_cvt_elements_per_thread_
103                         : 0;
104
105                     bf16_p_.acc_loop_step_ = is_dst_bf16_
106                         ? bf16_p_.ws_cvt_elements_per_thread_
107                         : 1;
108
109                     bf16_p_.ws_elements_per_thread_ = bf16_p_.ws_cvt_elements_per_thread_
110                         + bf16_p_.ws_acc_elements_per_thread_;
111                     size_t bf16cvt_buf_sz_ = sizeof(acc_data_t) * bf16_p_.ws_elements_per_thread_
112                         * mkldnn_get_max_threads();
113                     auto scratchpad = scratchpad_registry().registrar();
114                     scratchpad.book(memory_tracking::names::key_sum_bf16cvt, bf16cvt_buf_sz_);
115                 }
116             }
117     };
118
119     simple_sum_t(const pd_t *apd, const input_vector &inputs,
120             const output_vector &outputs)
121         : cpu_primitive_t(apd, inputs, outputs) {
122     }
123
124     virtual void execute(event_t *e) const {
125         execute();
126         e->set_state(event_t::ready);
127     }
128
129     enum {max_num_arrs = 16 };
130     typedef typename prec_traits<src_data_type>::type src_data_t;
131     typedef typename prec_traits<dst_data_type>::type dst_data_t;
132     typedef typename prec_traits<data_type::f32>::type acc_data_t;
133
134 private:
135     void execute() const;
136     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
137 };
138
139 }
140 }
141 }
142
143 #endif
144
145 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s