1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_NHWC_POOLING_HPP
18 #define CPU_NHWC_POOLING_HPP
22 #include "c_types_map.hpp"
23 #include "cpu_engine.hpp"
24 #include "cpu_pooling_pd.hpp"
25 #include "mkldnn_thread.hpp"
26 #include "type_helpers.hpp"
33 namespace nhwc_pooling {
// strided_offset: appears to take (index, stride) pairs for the n, d, h
// and w dimensions (int index followed by a size_t stride).
// NOTE(review): the rest of the parameter list and the definition are not
// visible in this listing; presumably it returns the linear element offset
// _n*_sn + _d*_sd + _h*_sh + _w*_sw -- confirm against the .cpp file.
34 size_t strided_offset(const int _n, const size_t _sn, const int _d,
35 const size_t _sd, const int _h, const size_t _sh, const int _w,
// nhwc_pooling_fwd_t: CPU forward pooling primitive restricted to the nhwc /
// ndhwc memory formats, templated on the element data type `data_type`.
// NOTE(review): this listing is missing lines from the original header
// (e.g. the `ok` initializer, the 5-D/4-D branch, the loop-local values in
// array_nhwc_max and the struct's closing brace); the comments below only
// describe what is visible here.
39 template <impl::data_type_t data_type>
40 struct nhwc_pooling_fwd_t: public cpu_primitive_t {
// Primitive descriptor: validates the pooling descriptor and, for max
// pooling in training, builds the workspace memory descriptor.
41 struct pd_t: public cpu_pooling_fwd_pd_t {
42 pd_t(engine_t *engine, const pooling_desc_t *adesc,
43 const primitive_attr_t *attr,
44 const pooling_fwd_pd_t *hint_fwd_pd)
45 : cpu_pooling_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
47 DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t);
// Returns status::unimplemented unless all of the following hold:
//  - prop_kind is forward_training (the remainder of this one_of list
//    is not visible in this listing),
//  - alg_kind is pooling_max, pooling_avg_include_padding or
//    pooling_avg_exclude_padding,
//  - src and dst both have element type `data_type`,
//  - src is nhwc or ndhwc and dst uses the same format,
//  - the primitive attributes are the defaults.
49 virtual status_t init() override {
50 using namespace prop_kind;
51 using namespace alg_kind;
52 using namespace memory_format;
53 assert(engine()->kind() == engine_kind::cpu);
54 auto src_format = src_pd()->desc()->format;
// NOTE(review): the declaration that begins this `&&` chain (the `ok`
// flag tested below) is not visible in this listing.
56 && set_default_params() == status::success
57 && utils::one_of(desc()->prop_kind, forward_training,
59 && utils::one_of(desc()->alg_kind, pooling_max,
60 pooling_avg_include_padding,
61 pooling_avg_exclude_padding)
62 && utils::everyone_is(data_type,
63 src_pd()->desc()->data_type,
64 dst_pd()->desc()->data_type)
65 && utils::one_of(src_format, nhwc, ndhwc)
66 && (src_format == dst_pd()->desc()->format)
67 && attr()->has_default_values();
68 if (!ok) return status::unimplemented;
70 bool is_training = desc_.prop_kind == forward_training;
// Only max pooling in training mode gets a workspace. It has the output's
// logical dims and stores, per output element, the index written by
// array_nhwc_max; its element type is pooling_index_data_type(desc()).
71 if (desc()->alg_kind == pooling_max && is_training) {
72 // Allocate dense workspace buffer based on logical dimensions
74 memory_desc_t indices_desc;
// 5-D case: dims {MB, C, OD, OH, OW} in ndhwc. The 4-D case that follows
// uses {MB, C, OH, OW}; the condition selecting between the two and the
// 4-D format argument are not visible in this listing.
76 dims_t ws_dims = { MB(), C(), OD(), OH(), OW() };
77 mkldnn_memory_desc_init(&indices_desc, 5, ws_dims,
78 pooling_index_data_type(desc()),
79 memory_format::ndhwc);
81 dims_t ws_dims = { MB(), C(), OH(), OW() };
82 mkldnn_memory_desc_init(&indices_desc, 4, ws_dims,
83 pooling_index_data_type(desc()),
86 ws_pd_ = cpu_memory_t::pd_t(engine_, &indices_desc);
89 return status::success;
93 nhwc_pooling_fwd_t(const pd_t *apd, const input_vector &inputs,
94 const output_vector &outputs)
95 : cpu_primitive_t(apd, inputs, outputs) {}
97 typedef typename prec_traits<data_type>::type data_t;
// Sets the event state to ready.
// NOTE(review): the statement that performs the computation (presumably a
// call to execute_forward()) is not visible in this listing. This method
// is also not marked `override`; consider adding it if it overrides a
// base-class virtual.
99 virtual void execute(event_t *e) const {
101 e->set_state(event_t::ready);
// Declarations only; the definitions are not in this listing.
// NOTE(review): the remaining parameters of array_div_by_const are not
// visible here.
105 void execute_forward() const;
106 void array_div_by_const(const int n, const data_t *src, const size_t num,
108 void array_add(const int n, const data_t *src, data_t *dst) const;
// Updates a running maximum over n channels: for each oc in [0, n),
// dst[oc] = max(s, mv); when use_workspace is true and s > mv, `index`
// is recorded in the workspace entry ws_offset + oc.
// ws_offset + oc counts ws_dt elements, not bytes: for s32 the buffer is
// reinterpreted as int* before indexing. A u8 workspace can only hold
// indices in [0, 255] (asserted below).
// NOTE(review): `s` and `mv` are defined on lines not visible here;
// presumably s = src[oc] (candidate) and mv = dst[oc] (current max) --
// confirm.
110 template <bool use_workspace>
111 void array_nhwc_max(const int n, data_t *dst, const data_t *src,
112 unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt,
113 const int index) const {
// ws must be non-null exactly when use_workspace is true.
114 assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists
116 for (int oc = 0; oc < n; ++oc) {
120 // update index of maximum
121 #if defined __INTEL_COMPILER
// Intel compiler path: plain conditional store of the index.
122 if ((use_workspace) && (s > mv)) {
123 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
124 if (ws_dt == data_type::u8) {
125 assert(0 <= index && index <= 255);
126 ws[ws_offset + oc] = index;
128 reinterpret_cast<int *>(ws)[ws_offset + oc] = index;
131 // Need to add explicit predicates for GCC to vectorize this.
132 // And although the resulting code is ugly, it is still 4 times
133 // faster than scalar
// Branchless select: `predicate` is all ones when s > mv, so the stored
// value becomes `index`; otherwise it is all zeros and the existing
// entry is written back unchanged.
135 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
137 if (ws_dt == data_type::u8) {
138 assert(0 <= index && index <= 255);
139 unsigned char predicate = (s > mv) ? 0xff : 0;
140 unsigned char current_value = ws[ws_offset + oc];
141 current_value = (predicate & (unsigned char)index)
142 | ((~predicate) & current_value);
143 ws[ws_offset + oc] = current_value;
145 auto wint = reinterpret_cast<int *>(ws);
146 unsigned int predicate = (s > mv) ? 0xffffffff : 0;
147 unsigned int current_value = wint[ws_offset + oc];
148 current_value = (predicate & (unsigned int)index)
149 | ((~predicate) & current_value);
150 wint[ws_offset + oc] = current_value;
155 dst[oc] = nstl::max(s, mv);
// Initializes n channels, presumably as the starting state for
// array_nhwc_max: each workspace entry ws_offset + oc (u8 or s32, counted
// in ws_dt elements) is zeroed, and dst[oc] is set to
// numeric_limits<data_t>::lowest().
// NOTE(review): the guard around the workspace writes (presumably
// `if (use_workspace)`) is not visible in this listing.
159 template <bool use_workspace>
160 void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws,
161 const size_t ws_offset, const data_type_t ws_dt) const {
// ws must be non-null exactly when use_workspace is true.
162 assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists
163 for (int oc = 0; oc < n; ++oc) {
165 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
166 if (ws_dt == data_type::u8) {
167 ws[ws_offset + oc] = 0;
169 reinterpret_cast<int *>(ws)[ws_offset + oc] = 0;
171 dst[oc] = nstl::numeric_limits<data_t>::lowest();
// Returns this primitive's descriptor cast to pd_t.
175 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// nhwc_pooling_bwd_t: CPU backward-data pooling primitive restricted to the
// nhwc / ndhwc memory formats, templated on the element data type
// `data_type`.
// NOTE(review): this listing is missing lines from the original header
// (e.g. the `ok` initializer, the `if (!ok)` guard, parts of the ws_ok
// expression and the struct's closing brace); the comments below only
// describe what is visible here.
178 template <impl::data_type_t data_type>
179 struct nhwc_pooling_bwd_t: public cpu_primitive_t {
// Primitive descriptor: validates the pooling descriptor and, for max
// pooling, adopts the workspace descriptor of the forward hint.
180 struct pd_t: public cpu_pooling_bwd_pd_t {
181 pd_t(engine_t *engine, const pooling_desc_t *adesc,
182 const primitive_attr_t *attr,
183 const pooling_fwd_pd_t *hint_fwd_pd)
184 : cpu_pooling_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
// NOTE(review): this implementation name ("nhwc:any") differs from the
// forward primitive's "nhwc_pooling:any" in this header; confirm whether
// the difference is intentional before relying on or changing it.
186 DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t);
// Returns status::unimplemented unless all of the following hold:
//  - prop_kind is backward_data,
//  - alg_kind is pooling_max, pooling_avg_include_padding or
//    pooling_avg_exclude_padding,
//  - diff_dst and diff_src both have element type `data_type`,
//  - diff_dst is nhwc or ndhwc and diff_src uses the same format,
//  - the primitive attributes are the defaults,
//  - for max pooling, the forward hint supplies a usable workspace.
188 virtual status_t init() override {
189 using namespace prop_kind;
190 using namespace alg_kind;
191 using namespace memory_format;
192 assert(engine()->kind() == engine_kind::cpu);
193 auto diff_dst_format = diff_dst_pd()->desc()->format;
// NOTE(review): the declaration that begins this `&&` chain (presumably
// an `ok` flag) is not visible in this listing.
195 && set_default_params() == status::success
196 && utils::one_of(desc()->prop_kind, backward_data)
197 && utils::one_of(desc()->alg_kind, pooling_max,
198 pooling_avg_include_padding,
199 pooling_avg_exclude_padding)
200 && utils::everyone_is(data_type,
201 diff_dst_pd()->desc()->data_type,
202 diff_src_pd()->desc()->data_type)
203 && utils::one_of(diff_dst_format, nhwc, ndhwc)
204 && (diff_dst_format == diff_src_pd()->desc()->format)
205 && attr()->has_default_values();
// NOTE(review): the condition guarding this return (presumably
// `if (!ok)`) is not visible in this listing.
207 return status::unimplemented;
// Max pooling needs the argmax workspace produced by the forward pass.
// ws_ok checks that the forward hint has a workspace pd, checks its
// format, and compares its engine kind with a value not visible here.
// NOTE(review): the leading operand(s) of ws_ok are not visible; confirm
// that hint_fwd_pd_ is null-checked before it is dereferenced below.
209 if (desc()->alg_kind == pooling_max) {
212 && hint_fwd_pd_->workspace_pd()
214 hint_fwd_pd_->workspace_pd()->desc()->format,
216 && hint_fwd_pd_->workspace_pd()->engine()->kind()
218 if (!ws_ok) return status::unimplemented;
// Copies the forward workspace descriptor. The C-style cast assumes the
// hint's workspace pd really is a cpu_memory_t::pd_t; presumably the
// engine-kind check in ws_ok guarantees this -- confirm.
220 ws_pd_ = *(cpu_memory_t::pd_t *)hint_fwd_pd_->workspace_pd();
223 return status::success;
227 nhwc_pooling_bwd_t(const pd_t *apd, const input_vector &inputs,
228 const output_vector &outputs)
229 : cpu_primitive_t(apd, inputs, outputs) {}
230 typedef typename prec_traits<data_type>::type data_t;
// Sets the event state to ready.
// NOTE(review): the statement that performs the computation (presumably a
// call to execute_backward()) is not visible in this listing. This method
// is also not marked `override`; consider adding it if it overrides a
// base-class virtual.
232 virtual void execute(event_t *e) const {
234 e->set_state(event_t::ready);
// Declaration only; the definition is not in this listing.
238 void execute_backward() const;
// Returns this primitive's descriptor cast to pd_t.
239 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
248 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s