Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / wino_reorder.hpp
1 /*******************************************************************************
2  * Copyright 2017-2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #ifndef CPU_WINO_REORDER_HPP
18 #define CPU_WINO_REORDER_HPP
19
20 namespace mkldnn {
21 namespace impl {
22 namespace cpu {
23
24 template <data_type_t type_i, data_type_t type_o>
25 struct wino_reorder_t : public cpu_primitive_t {
26     struct pd_t : public cpu_reorder_pd_t {
27         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
28                 const primitive_attr_t *attr)
29             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
30
31         DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t);
32
33         static status_t create(reorder_pd_t **reorder_pd,
34                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
35                 const primitive_attr_t *attr) {
36             assert(input_pd->engine()->kind() == engine_kind::cpu);
37             assert(output_pd->engine()->kind() == engine_kind::cpu);
38
39             const memory_desc_wrapper id(input_pd), od(output_pd);
40             bool args_ok = true
41                 && id.data_type() == type_i
42                 && od.data_type() == type_o
43                 && utils::one_of(id.format(), goihw, oihw)
44                 && od.format() == wino_fmt
45                 && one_of(od.wino_desc().wino_format,
46                         mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio,
47                         mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio);
48             if (!args_ok) return status::invalid_arguments;
49
50             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
51                     (const cpu_memory_pd_t *)output_pd, attr);
52             if (_pd == nullptr) return out_of_memory;
53             if (_pd->init() != success) { delete _pd; return unimplemented; }
54             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
55         }
56
57         virtual status_t init() override {
58             status_t status = cpu_reorder_pd_t::init();
59             if (status != status::success) return status;
60
61             init_scratchpad();
62
63             return status::success;
64         }
65
66     private:
67         void init_scratchpad() {
68             auto &o = memory_desc_wrapper(output_pd()).wino_desc();
69             size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block;
70             size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic;
71
72             using namespace memory_tracking::names;
73             auto scratchpad = scratchpad_registry().registrar();
74             scratchpad.book(key_reorder_wino_transform_space,
75                     sizeof(in_data_t) * transform_space_size);
76             scratchpad.book(key_reorder_wino_plain,
77                     sizeof(out_data_t) * plain_size);
78         }
79     };
80
81 private:
82     typedef typename prec_traits<type_i>::type in_data_t;
83     typedef typename prec_traits<type_o>::type out_data_t;
84     const int unsign_val_in_wino_domain_ = 5;
85
86     wino_reorder_t(const pd_t *apd, const input_vector &inputs,
87             const output_vector &outputs)
88         : cpu_primitive_t(apd, inputs, outputs)
89     {
90         const memory_desc_wrapper input_d(pd()->input_pd());
91         const memory_desc_wrapper output_d(pd()->output_pd());
92
93         r_ = output_d.wino_desc().r;
94         w_alpha_ = output_d.wino_desc().alpha;
95         wino_format_ = output_d.wino_desc().wino_format;
96
97         const auto &in_dims = input_d.dims();
98         int groups;
99         int groups_offset;
100         if (input_d.format() == goihw) {
101             groups = in_dims[0];
102             groups_offset = 1;
103         } else {
104             groups = 1;
105             groups_offset = 0;
106         }
107         assert(groups == 1); // groups are not supported now
108         MAYBE_UNUSED(groups);
109
110         or_oc_ = in_dims[0 + groups_offset];
111         or_ic_ = in_dims[1 + groups_offset];
112         kh_ = in_dims[2 + groups_offset];
113         kw_ = in_dims[3 + groups_offset];
114
115         oc_ = output_d.wino_desc().oc;
116         ic_ = output_d.wino_desc().ic;
117         oc_block_ = output_d.wino_desc().oc_block;
118         ic_block_ = output_d.wino_desc().ic_block;
119         assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0);
120         nb_oc_ = oc_ / oc_block_;
121         nb_ic_ = ic_ / ic_block_;
122         ic2_block_ = 1;
123         if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio)
124             ic2_block_ = output_d.wino_desc().ic2_block;
125         oc2_block_ = output_d.wino_desc().oc2_block;
126         assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0);
127
128         adj_scale_ = output_d.wino_desc().adj_scale;
129
130         size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_;
131         size_wspace_ = r_ * w_alpha_ * oc_block_;
132     }
133
134     void transform(out_data_t *__restrict tmp_wei,
135             const in_data_t *__restrict input,
136             in_data_t *__restrict wspace) const {
137         const memory_desc_wrapper input_d(pd()->input_pd()->desc());
138
139         round_mode_t rmode = pd()->attr()->round_mode_;
140         const int smask = pd()->attr()->output_scales_.mask_;
141         const int ndims_mask = math::ilog2q(smask + 1);
142         const size_t D_mask = utils::array_product(input_d.dims(), ndims_mask);
143         const float *__restrict scales = pd()->attr()->output_scales_.scales_;
144         assert(D_mask == 1 || D_mask == (size_t)oc_);
145
146         /* transform weights to winograd domain */
147         const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 },
148             { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } };
149
150         const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f },
151             { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f },
152             { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f },
153             { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f },
154             { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f },
155             { 0.f, 0.f, 1.f } };
156
157         float *__restrict g;
158         if (one_of(wino_format_, mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio,
159             mkldnn_wino_wei_aaOBiOo))
160             g = (float *)G_2x2_3x3;
161         else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio)
162             g = (float *)G_4x4_3x3;
163         else {
164             assert("Unknown winograd weights target layout");
165             return;
166         }
167
168         int Z = oc_ * ic_;
169         assert(r_ == kh_ && r_ == kw_);
170
171         for (int iic = 0; iic < ic_; iic++) {
172         for (int ob = 0; ob < nb_oc_; ob++) {
173             const in_data_t *__restrict _inp
174                     = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_;
175             out_data_t *__restrict _out
176                     = tmp_wei + (iic * nb_oc_ + ob) * oc_block_;
177
178             parallel_nd(size_wspace_, [&](int i) { wspace[i] = 0.f; });
179
180             parallel_nd(r_, w_alpha_, oc_block_,
181                 [&](int ih, int j, int ioc) {
182                 for (int iw = 0; iw < r_; ++iw) {
183                     int inp_oc = ob * oc_block_ + ioc;
184                     int inp_ic = iic;
185                     in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_)
186                         ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw]
187                         : 0.f;
188                     wspace[(ih * w_alpha_ + j) * oc_block_ + ioc]
189                             += inp_v * g[j * r_ + iw];
190                 }
191             });
192
193             parallel_nd(w_alpha_, w_alpha_, oc_block_,
194                 [&](int i, int j, int ioc) {
195                 float t = 0;
196                 for (int k = 0; k < r_; ++k)
197                     t += g[i * r_ + k]
198                             * wspace[(k * w_alpha_ + j) * oc_block_ + ioc];
199                 if (type_o == s8) {
200                     const float scale = (D_mask == 1)
201                         ? scales[0]
202                         : scales[ob * oc_block_ + ioc];
203                     _out[(i * w_alpha_ + j) * Z + ioc]
204                             = qz_b0<in_data_t, out_data_t>()(
205                                     (in_data_t)t, scale * adj_scale_, rmode);
206                 } else {
207                     _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t;
208                 }
209             });
210         }}
211     }
212
213     void reorder_to_aaOIoi(out_data_t *__restrict output,
214             const out_data_t *__restrict tmp_wei) const {
215         int32_t *__restrict dst_bias = nullptr;
216         if (type_o == s8) {
217             const auto bias_shift = sizeof(out_data_t) * size_wino_wei_;
218             const size_t bias_size = w_alpha_ * w_alpha_ * oc_;
219
220             dst_bias = (int32_t *)(output + bias_shift);
221             utils::array_set((int32_t *)dst_bias, 0, bias_size);
222         }
223         int index = 0;
224         for (int u_h = 0; u_h < w_alpha_; u_h++) {
225         for (int u_w = 0; u_w < w_alpha_; u_w++) {
226             parallel_nd(nb_oc_, oc_block_, [&](int ob, int o) {
227                 int u_h_shift = u_h * w_alpha_ * ic_ * oc_;
228                 int u_w_shift = u_w * ic_ * oc_;
229                 int u_h_shift_b = u_h * w_alpha_ * oc_;
230                 int u_w_shift_b = u_w * oc_;
231                 int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_;
232                 for (int ib = 0; ib < nb_ic_; ib++) {
233                 for (int i = 0; i < ic_block_; i++) {
234                     int _i = ib * ic_block_;
235                     int _o = ob * oc_block_;
236                     int ic_shift = (_i + i) * oc_;
237                     int oc_shift = (_o + o);
238                     int ic_block_shift = ib * oc_block_ * ic_block_ + i;
239                     int src_offset =
240                             u_h_shift + u_w_shift + ic_shift + oc_shift;
241                     int dst_offset = u_h_shift + u_w_shift + oc_block_shift
242                             + ic_block_shift;
243
244                     output[dst_offset] = tmp_wei[src_offset];
245                     if (type_o == s8) {
246                         int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift;
247                         if (index != unsign_val_in_wino_domain_)
248                             dst_bias[bias_offset]
249                                     -= (128 * (int32_t)output[dst_offset]);
250                         else
251                             dst_bias[bias_offset] = 0;
252                     }
253                 }}
254             });
255             index++;
256         }}
257     }
258
259     void reorder_to_aaOio(out_data_t *__restrict output,
260             const out_data_t *__restrict tmp_wei) const {
261         parallel_nd(w_alpha_, w_alpha_, nb_oc_,
262             [&](int u_h, int u_w, int ob) {
263             for (int ib = 0; ib < nb_ic_; ib++) {
264             for (int i = 0; i < ic_block_; i++) {
265             for (int o = 0; o < oc_block_; o++) {
266                 int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_
267                     + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o);
268
269                 int dst_offset
270                     = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_
271                     + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_
272                     + ob * nb_ic_ * ic_block_ * oc_block_
273                     + ib * ic_block_ * oc_block_ + i * oc_block_ + o;
274                 output[dst_offset] = tmp_wei[src_offset];
275             }}}
276         });
277     }
278
279     void reorder_to_aaOBiOo(out_data_t *__restrict output,
280             const out_data_t *__restrict tmp_wei) const {
281         int oc_chunks = nb_oc_ / oc2_block_;
282
283         parallel_nd(w_alpha_, w_alpha_, oc_chunks,
284             [&](int u_h, int u_w, int occ) {
285             for (int ib = 0; ib < nb_ic_; ib++) {
286                 out_data_t *__restrict wei_ptr = output
287                     + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib)
288                     * oc2_block_ * ic_block_ * oc_block_;
289                 int wei_offset = 0;
290                 for (int i = 0; i < ic_block_; i++) {
291                 for (int ob2 = 0; ob2 < oc2_block_; ob2++) {
292                     for (int o = 0; o < oc_block_; o++) {
293                         int icp = ib * ic_block_ + i;
294                         int ocp =
295                             occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o;
296
297                         int src_offset = u_h * w_alpha_ * ic_ * oc_
298                             + u_w * ic_ * oc_ + icp * oc_ + ocp;
299                         wei_ptr[wei_offset + o] = tmp_wei[src_offset];
300                     }
301                     wei_offset += oc_block_;
302                 }}
303             }
304         });
305     }
306
307     void reorder_to_OBaaIBOIio(out_data_t *__restrict output,
308             const out_data_t *__restrict tmp_wei) const {
309         int ic_chunks = nb_ic_ / ic2_block_;
310         int oc_chunks = nb_oc_ / oc2_block_;
311
312         parallel_nd(oc_chunks, w_alpha_, w_alpha_,
313             [&](int occ, int u_h, int u_w) {
314             for (int icc = 0; icc < ic_chunks; icc++) {
315             for (int ob = 0; ob < oc2_block_; ob++) {
316                 int ocp = (occ * oc2_block_ + ob) * oc_block_;
317                 for (int ib = 0; ib < ic2_block_; ib++) {
318                 for (int i = 0; i < ic_block_; i++) {
319                     int icp = (icc * ic2_block_ + ib) * ic_block_ + i;
320
321                     int src_offset = u_h * w_alpha_ * ic_ * oc_
322                         + u_w * ic_ * oc_ + icp * oc_ + ocp;
323                     int wei_offset
324                         = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w)
325                             * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_
326                             + ib) * ic_block_ + i) * oc_block_;
327                     for (int o = 0; o < oc_block_; o++)
328                         output[wei_offset + o] = tmp_wei[src_offset + o];
329                 }}
330             }}
331         });
332     }
333
334     virtual void execute(event_t *e) const {
335         auto input = reinterpret_cast<const in_data_t *>(input_memory(0));
336         auto output = reinterpret_cast<out_data_t *>(memory());
337
338         auto wspace = (in_data_t *__restrict)scratchpad().template get<void>(
339                 memory_tracking::names::key_reorder_wino_transform_space);
340         auto tmp_wei = (out_data_t *__restrict)scratchpad().template get<void>(
341                 memory_tracking::names::key_reorder_wino_plain);
342
343         transform(tmp_wei, input, wspace);
344
345         /* reorder to winograd domain */
346         switch (wino_format_) {
347         case mkldnn_wino_wei_aaOIoi:
348             reorder_to_aaOIoi(output, tmp_wei); break;
349         case mkldnn_wino_wei_aaOio:
350             reorder_to_aaOio(output, tmp_wei); break;
351         case mkldnn_wino_wei_aaOBiOo:
352             reorder_to_aaOBiOo(output, tmp_wei); break;
353         case mkldnn_wino_wei_OBaaIBOIio:
354             reorder_to_OBaaIBOIio(output, tmp_wei); break;
355         default: assert("Unknown wino format"); break;
356         }
357
358         e->set_state(event_t::ready);
359     }
360
361     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
362     int r_, w_alpha_;
363     int ic_, oc_, or_ic_, or_oc_, kh_, kw_;
364     int oc_block_, ic_block_, oc2_block_, ic2_block_;
365     float adj_scale_;
366     int nb_oc_, nb_ic_;
367     mkldnn_wino_memory_format_t wino_format_;
368     int size_wino_wei_;
369     int size_wspace_;
370 };
371
372 } // namespace cpu
373 } // namespace impl
374 } // namespace mkldnn
375
376 #endif