updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / nchw_pooling.hpp
1 /*******************************************************************************
2 * Copyright 2017-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_NCHW_POOLING_HPP
18 #define CPU_NCHW_POOLING_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "cpu_pooling_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27 #include "bfloat16_utils.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::memory_format;
34
35 template <data_type_t d_type>
36 struct nchw_pooling_fwd_t: public cpu_primitive_t {
37     struct pd_t: public cpu_pooling_fwd_pd_t {
38         pd_t(engine_t *engine, const pooling_desc_t *adesc,
39                 const primitive_attr_t *attr,
40                 const pooling_fwd_pd_t *hint_fwd_pd)
41             : cpu_pooling_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
42
43         DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t);
44
45         virtual status_t init() override {
46             using namespace prop_kind;
47             using namespace alg_kind;
48             assert(engine()->kind() == engine_kind::cpu);
49             auto src_format = src_pd()->desc()->format;
50             bool ok = true
51                 && set_default_params() == status::success
52                 && utils::one_of(desc()->prop_kind, forward_training,
53                         forward_inference)
54                 && utils::one_of(desc()->alg_kind, pooling_max,
55                         pooling_avg_include_padding,
56                         pooling_avg_exclude_padding)
57                 && !has_zero_dim_memory()
58                 && utils::everyone_is(d_type, src_pd()->desc()->data_type,
59                         dst_pd()->desc()->data_type)
60                 && utils::one_of(src_format, nchw, ncdhw)
61                 && (src_format == dst_pd()->desc()->format)
62                 && attr()->has_default_values();
63             if (!ok) return status::unimplemented;
64
65             bool is_training = desc_.prop_kind == forward_training;
66             if (desc()->alg_kind == pooling_max && is_training) {
67                 auto indices_desc = *dst_pd()->desc();
68                 indices_desc.data_type = pooling_index_data_type(desc());
69                 ws_pd_ = cpu_memory_t::pd_t(engine_, &indices_desc);
70             }
71
72             init_scratchpad();
73
74             return status::success;
75         }
76
77         private:
78             void init_scratchpad() {
79                 using namespace memory_tracking::names;
80                 if (src_pd()->desc()->data_type == data_type::bf16) {
81                     size_t src_sz_ = ID() * IH() * IW() * C() * MB();
82                     auto scratchpad = scratchpad_registry().registrar();
83                     scratchpad.book(key_pool_src_bf16cvt, sizeof(float) * src_sz_);
84                 }
85             }
86     };
87
88     nchw_pooling_fwd_t(const pd_t *apd, const input_vector &inputs,
89             const output_vector &outputs)
90         : cpu_primitive_t(apd, inputs, outputs) {}
91
92     ~nchw_pooling_fwd_t() {}
93
94     typedef typename prec_traits<d_type>::type data_t;
95
96     virtual void execute(event_t *e) const {
97         execute_forward();
98         e->set_state(event_t::ready);
99     }
100
101 private:
102     void execute_forward() const;
103     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
104 };
105
106 template <data_type_t d_type>
107 struct nchw_pooling_bwd_t: public cpu_primitive_t {
108     struct pd_t: public cpu_pooling_bwd_pd_t {
109         pd_t(engine_t *engine, const pooling_desc_t *adesc,
110                 const primitive_attr_t *attr,
111                 const pooling_fwd_pd_t *hint_fwd_pd)
112             : cpu_pooling_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
113
114         DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t);
115
116         virtual status_t init() override {
117             using namespace prop_kind;
118             using namespace alg_kind;
119             assert(engine()->kind() == engine_kind::cpu);
120             auto diff_dst_format = diff_dst_pd()->desc()->format;
121             bool ok = true
122                 && set_default_params() == status::success
123                 && utils::one_of(desc()->prop_kind, backward_data)
124                 && utils::one_of(desc()->alg_kind, pooling_max,
125                         pooling_avg_include_padding,
126                         pooling_avg_exclude_padding)
127                 && !has_zero_dim_memory()
128                 && utils::everyone_is(d_type,
129                         diff_dst_pd()->desc()->data_type,
130                         diff_src_pd()->desc()->data_type)
131                 && utils::one_of(diff_dst_format, nchw, ncdhw)
132                 && (diff_dst_format == diff_src_pd()->desc()->format)
133                 && attr()->has_default_values();
134             if (!ok) return status::unimplemented;
135
136             if (desc()->alg_kind == pooling_max) {
137                 bool ws_ok = true
138                     && hint_fwd_pd_
139                     && hint_fwd_pd_->workspace_pd()
140                     && utils::one_of(
141                             hint_fwd_pd_->workspace_pd()->desc()->format,
142                             nchw, nChw8c, nChw16c, ncdhw, nCdhw8c, nCdhw16c);
143                 if (!ws_ok) return status::unimplemented;
144
145                 ws_pd_ = *(cpu_memory_t::pd_t*)hint_fwd_pd_->workspace_pd();
146             }
147
148             init_scratchpad();
149
150             return status::success;
151         }
152
153         private:
154             void init_scratchpad() {
155                 using namespace memory_tracking::names;
156                 if (diff_src_pd()->desc()->data_type == data_type::bf16) {
157                     size_t dst_sz_ = OD() * OH() * OW();
158                     size_t src_sz_ = ID() * IH() * IW();
159                     size_t nthrs = mkldnn_get_max_threads();
160                     auto scratchpad = scratchpad_registry().registrar();
161                     scratchpad.book(key_pool_src_bf16cvt,
162                             sizeof(float) * src_sz_ * nthrs);
163                     scratchpad.book(key_pool_dst_bf16cvt,
164                             sizeof(float) * dst_sz_ * nthrs);
165                 }
166             }
167     };
168
169     nchw_pooling_bwd_t(const pd_t *apd, const input_vector &inputs,
170             const output_vector &outputs)
171         : cpu_primitive_t(apd, inputs, outputs) {}
172     ~nchw_pooling_bwd_t() {}
173
174     typedef typename prec_traits<d_type>::type data_t;
175
176     virtual void execute(event_t *e) const {
177         execute_backward();
178         e->set_state(event_t::ready);
179     }
180
181 private:
182     void execute_backward() const;
183     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
184 };
185
186 }
187 }
188 }
189
190 #endif
191
192 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s