1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
18 #include "mkldnn_types.h"
20 #include "c_types_map.hpp"
21 #include "memory_desc_wrapper.hpp"
22 #include "memory_pd.hpp"
23 #include "type_helpers.hpp"
29 memory_desc_wrapper::memory_desc_wrapper(const memory_pd_t *m_pd)
30 : _md(m_pd == nullptr ? nullptr : m_pd->desc()) {}
33 using namespace mkldnn::impl::utils;
34 using namespace mkldnn::impl::status;
35 using namespace mkldnn::impl::memory_format;
/* Fill the blocking descriptor of `md` for the 1D `x` format: a single
 * unblocked dimension with unit stride and no padding.
 * Returns invalid_arguments unless md.ndims == 1. */
status_t fill_x(memory_desc_t &md) {
    const int ndims = md.ndims;
    if (ndims != 1) return invalid_arguments;
    blocking_desc_t &blk = md.layout_desc.blocking;
    // No blocking: block size 1 and unit inner-block stride (strides[1]).
    array_set(blk.block_dims, 1, ndims);
    array_set(blk.strides[1], 1, ndims);
    // The single outer dimension is contiguous.
    blk.strides[0][0] = 1;
    // Padded dims equal the logical dims, with zero padding offsets.
    array_copy(blk.padding_dims, md.dims, ndims);
    array_set(blk.offset_padding_to_data, 0, ndims);
    blk.offset_padding = 0;
/* TODO: consider improving this helper and moving it to utils. */
51 inline void set_default_strides(strides_t strides, const dims_t dims,
52 int ndims, const int *perm = NULL) {
53 int id_perm[TENSOR_MAX_DIMS] = {0};
54 for (int i = 0; i < ndims; ++i)
59 strides[perm[ndims - 1]] = 1;
60 for (int d = 1; d < ndims; ++d) {
61 const int prev_idx = perm[ndims - d];
62 const int curr_idx = perm[ndims - 1 - d];
64 strides[curr_idx] = dims[curr_idx] == 0
66 : strides[prev_idx] * nstl::max((ptrdiff_t)1, dims[prev_idx]);
70 status_t fill_nonblocked(memory_desc_t &md, const int perm[]) {
71 const int ndims = md.ndims;
72 blocking_desc_t &blk = md.layout_desc.blocking;
73 array_set(blk.block_dims, 1, ndims);
74 array_set(blk.strides[1], 1, ndims);
76 if (md.format == mkldnn_nhwc && md.data_type == mkldnn_bin) {
79 const dims_t block_dims = {1, 8, 1, 1};
80 for (int d = 0; d < ndims; ++d) {
81 padding_dims[d] = rnd_up(md.dims[d], block_dims[d]);
84 set_default_strides(blk.strides[0], padding_dims, ndims, perm);
85 array_copy(blk.padding_dims, padding_dims, ndims);
88 set_default_strides(blk.strides[0], md.dims, ndims, perm);
89 array_copy(blk.padding_dims, md.dims, ndims);
92 array_set(blk.offset_padding_to_data, 0, ndims);
93 blk.offset_padding = 0;
98 status_t fill_contiguous_blocked(memory_desc_t &md, const dims_t block_dims,
100 const int ndims = md.ndims;
102 blocking_desc_t &blk = md.layout_desc.blocking;
103 array_copy(blk.block_dims, block_dims, ndims);
105 dim_t unrolled_dims[2*TENSOR_MAX_DIMS];
106 stride_t unrolled_strides[2*TENSOR_MAX_DIMS];
109 for (int d = 0; d < ndims; ++d) {
110 unrolled_dims[d] = div_up(md.dims[d], block_dims[d]);
111 unrolled_dims[ndims + d] = block_dims[d];
112 padding_dims[d] = rnd_up(md.dims[d], block_dims[d]);
115 set_default_strides(unrolled_strides, unrolled_dims, 2*ndims, perm);
116 array_copy(blk.strides[0], &unrolled_strides[0], ndims);
117 array_copy(blk.strides[1], &unrolled_strides[ndims], ndims);
118 array_copy(blk.padding_dims, padding_dims, ndims);
119 array_set(blk.offset_padding_to_data, 0, ndims);
120 blk.offset_padding = 0;
/* `nc`: plain 2D layout. perm lists logical dims from outermost to
 * innermost (its last entry gets stride 1), so dim 1 is contiguous. */
status_t fill_nc(memory_desc_t &md) {
    if (md.ndims != 2) return invalid_arguments;

    const int perm[2] = {0, 1};
    return fill_nonblocked(md, perm);
/* `ncw`: plain 3D layout in logical order; dim 2 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_ncw(memory_desc_t &md) {
    if (md.ndims != 3) return invalid_arguments;

    const int perm[3] = {0, 1, 2};
    return fill_nonblocked(md, perm);
/* `nwc`: plain 3D layout with logical dim 1 moved innermost (contiguous);
 * perm runs outermost to innermost. */
status_t fill_nwc(memory_desc_t &md) {
    if (md.ndims != 3) return invalid_arguments;

    const int perm[3] = {0, 2, 1};
    return fill_nonblocked(md, perm);
145 status_t fill_nCw4c(memory_desc_t &md) {
146 if (md.ndims != 3) return invalid_arguments;
148 const dims_t block_dims = { 1, 4, 1 };
152 return fill_contiguous_blocked(md, block_dims, perm);
156 status_t fill_nCw8c(memory_desc_t &md) {
157 if (md.ndims != 3) return invalid_arguments;
159 const dims_t block_dims = {1, 8, 1, 1};
163 return fill_contiguous_blocked(md, block_dims, perm);
166 status_t fill_nCw16c(memory_desc_t &md) {
167 if (md.ndims != 3) return invalid_arguments;
169 const dims_t block_dims = {1, 16, 1};
173 return fill_contiguous_blocked(md, block_dims, perm);
/* `nchw`: plain 4D layout in logical order; dim 3 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_nchw(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {0, 1, 2, 3};
    return fill_nonblocked(md, perm);
/* `ncdhw`: plain 5D layout in logical order; dim 4 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_ncdhw(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = {0, 1, 2, 3, 4};
    return fill_nonblocked(md, perm);
/* `oidhw`: plain 5D weights layout in logical order; dim 4 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_oidhw(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = {0, 1, 2, 3, 4};
    return fill_nonblocked(md, perm);
/* `goidhw`: plain 6D grouped-weights layout in logical order; dim 5 is
 * contiguous (perm runs outermost to innermost). */
status_t fill_goidhw(memory_desc_t &md) {
    if (md.ndims != 6) return invalid_arguments;

    const int perm[6] = {0, 1, 2, 3, 4, 5};
    return fill_nonblocked(md, perm);
/* `nhwc`: plain 4D layout with logical dim 1 innermost (contiguous); perm
 * runs outermost to innermost. For the bin data type, fill_nonblocked also
 * rounds dim 1 up to a multiple of 8 in the padded dims. */
status_t fill_nhwc(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {0, 2, 3, 1};
    return fill_nonblocked(md, perm);
/* `ndhwc`: plain 5D layout with logical dim 1 innermost (contiguous);
 * perm runs outermost to innermost. */
status_t fill_ndhwc(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = {0, 2, 3, 4, 1};
    return fill_nonblocked(md, perm);
/* `chwn`: plain 4D layout with logical dim 0 innermost (contiguous);
 * perm runs outermost to innermost. */
status_t fill_chwn(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {1, 2, 3, 0};
    return fill_nonblocked(md, perm);
225 status_t fill_nChw4c(memory_desc_t &md) {
226 if (md.ndims != 4) return invalid_arguments;
228 const dims_t block_dims = { 1, 4, 1, 1 };
232 return fill_contiguous_blocked(md, block_dims, perm);
235 status_t fill_nChw8c(memory_desc_t &md) {
236 if (md.ndims != 4) return invalid_arguments;
238 const dims_t block_dims = {1, 8, 1, 1};
242 return fill_contiguous_blocked(md, block_dims, perm);
245 status_t fill_nChw16c(memory_desc_t &md) {
246 if (md.ndims != 4) return invalid_arguments;
248 const dims_t block_dims = {1, 16, 1, 1};
252 return fill_contiguous_blocked(md, block_dims, perm);
255 status_t fill_nCdhw16c(memory_desc_t &md) {
256 if (md.ndims != 5) return invalid_arguments;
258 const dims_t block_dims = {1, 16, 1, 1, 1};
262 return fill_contiguous_blocked(md, block_dims, perm);
265 status_t fill_nCdhw4c(memory_desc_t &md) {
266 if (md.ndims != 5) return invalid_arguments;
268 const dims_t block_dims = { 1, 4, 1, 1, 1 };
272 return fill_contiguous_blocked(md, block_dims, perm);
275 status_t fill_nCdhw8c(memory_desc_t &md) {
276 if (md.ndims != 5) return invalid_arguments;
278 const dims_t block_dims = {1, 8, 1, 1, 1};
282 return fill_contiguous_blocked(md, block_dims, perm);
/* `oi`: plain 2D weights layout in logical order; dim 1 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_oi(memory_desc_t &md) {
    if (md.ndims != 2) return invalid_arguments;

    const int perm[2] = {0, 1};
    return fill_nonblocked(md, perm);
/* `io`: plain 2D weights layout, transposed: logical dim 0 is innermost
 * (contiguous); perm runs outermost to innermost. */
status_t fill_io(memory_desc_t &md) {
    if (md.ndims != 2) return invalid_arguments;

    const int perm[2] = {1, 0};
    return fill_nonblocked(md, perm);
/* `oiw`: plain 3D weights layout in logical order; dim 2 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_oiw(memory_desc_t &md) {
    if (md.ndims != 3) return invalid_arguments;

    const int perm[3] = {0, 1, 2};
    return fill_nonblocked(md, perm);
/* `wio`: plain 3D weights layout in fully reversed order; logical dim 0
 * is innermost (contiguous). perm runs outermost to innermost. */
status_t fill_wio(memory_desc_t &md) {
    if (md.ndims != 3) return invalid_arguments;

    const int perm[3] = {2, 1, 0};
    return fill_nonblocked(md, perm);
313 status_t fill_Owi4o(memory_desc_t &md) {
314 if (md.ndims != 3) return invalid_arguments;
316 const dims_t block_dims = { 4, 1, 1 };
320 return fill_contiguous_blocked(md, block_dims, perm);
323 status_t fill_Owi8o(memory_desc_t &md) {
324 if (md.ndims != 3) return invalid_arguments;
326 const dims_t block_dims = {8, 1, 1};
330 return fill_contiguous_blocked(md, block_dims, perm);
333 status_t fill_OIw8o8i(memory_desc_t &md) {
334 if (md.ndims != 3) return invalid_arguments;
336 const dims_t block_dims = {8, 8, 1};
340 return fill_contiguous_blocked(md, block_dims, perm);
343 status_t fill_OIw4i4o(memory_desc_t &md) {
344 if (md.ndims != 3) return invalid_arguments;
346 const dims_t block_dims = { 4, 4, 1 };
350 return fill_contiguous_blocked(md, block_dims, perm);
353 status_t fill_OIw8i8o(memory_desc_t &md) {
354 if (md.ndims != 3) return invalid_arguments;
356 const dims_t block_dims = {8, 8, 1};
360 return fill_contiguous_blocked(md, block_dims, perm);
363 status_t fill_OIw16i16o(memory_desc_t &md) {
364 if (md.ndims != 3) return invalid_arguments;
366 const dims_t block_dims = {16, 16, 1};
370 return fill_contiguous_blocked(md, block_dims, perm);
373 status_t fill_OIw16o16i(memory_desc_t &md) {
374 if (md.ndims != 3) return invalid_arguments;
376 const dims_t block_dims = {16, 16, 1};
380 return fill_contiguous_blocked(md, block_dims, perm);
383 status_t fill_Oiw4o(memory_desc_t &md) {
384 if (md.ndims != 3) return invalid_arguments;
386 const dims_t block_dims = {4, 1, 1};
390 return fill_contiguous_blocked(md, block_dims, perm);
393 status_t fill_Oiw16o(memory_desc_t &md) {
394 if (md.ndims != 3) return invalid_arguments;
396 const dims_t block_dims = { 16, 1, 1 };
400 return fill_contiguous_blocked(md, block_dims, perm);
403 status_t fill_Owi16o(memory_desc_t &md) {
404 if (md.ndims != 3) return invalid_arguments;
406 const dims_t block_dims = {16, 1, 1};
410 return fill_contiguous_blocked(md, block_dims, perm);
413 status_t fill_OIw8i16o2i(memory_desc_t &md) {
414 if (md.ndims != 3) return invalid_arguments;
416 const dims_t block_dims = {16, 16, 1};
420 return fill_contiguous_blocked(md, block_dims, perm);
423 status_t fill_IOw16o16i(memory_desc_t &md) {
424 if (md.ndims != 3) return invalid_arguments;
426 const dims_t block_dims = {16, 16, 1};
430 return fill_contiguous_blocked(md, block_dims, perm);
433 status_t fill_OIw8o16i2o(memory_desc_t &md) {
434 if (md.ndims != 3) return invalid_arguments;
436 const dims_t block_dims = {16, 16, 1};
440 return fill_contiguous_blocked(md, block_dims, perm);
/* `oihw`: plain 4D weights layout in logical order; dim 3 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_oihw(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {0, 1, 2, 3};
    return fill_nonblocked(md, perm);
/* `ihwo`: plain 4D weights layout with logical dim 0 innermost
 * (contiguous); perm runs outermost to innermost. */
status_t fill_ihwo(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {1, 2, 3, 0};
    return fill_nonblocked(md, perm);
/* `hwio` (also used for hwio_s8s8 by compute_blocking): plain 4D weights
 * layout with logical dims 2, 3 outermost and dim 0 innermost (contiguous). */
status_t fill_hwio(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {2, 3, 1, 0};
    return fill_nonblocked(md, perm);
/* `iohw`: plain 4D weights layout with logical dims 0 and 1 swapped;
 * dim 3 stays contiguous (perm runs outermost to innermost). */
status_t fill_iohw(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {1, 0, 2, 3};
    return fill_nonblocked(md, perm);
/* `dhwio`: plain 5D weights layout with spatial dims 2..4 outermost and
 * logical dim 0 innermost (contiguous). */
status_t fill_dhwio(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = {2, 3, 4, 1, 0};
    return fill_nonblocked(md, perm);
478 status_t fill_OIhw4i4o(memory_desc_t &md) {
479 if (md.ndims != 4) return invalid_arguments;
481 const dims_t block_dims = { 4, 4, 1, 1 };
485 return fill_contiguous_blocked(md, block_dims, perm);
488 status_t fill_OIhw8i8o(memory_desc_t &md) {
489 if (md.ndims != 4) return invalid_arguments;
491 const dims_t block_dims = {8, 8, 1, 1};
495 return fill_contiguous_blocked(md, block_dims, perm);
498 status_t fill_OIhw16i16o(memory_desc_t &md) {
499 if (md.ndims != 4) return invalid_arguments;
501 const dims_t block_dims = {16, 16, 1, 1};
505 return fill_contiguous_blocked(md, block_dims, perm);
508 status_t fill_OIdhw16i16o(memory_desc_t &md) {
509 if (md.ndims != 5) return invalid_arguments;
511 const dims_t block_dims = {16, 16, 1, 1, 1};
515 return fill_contiguous_blocked(md, block_dims, perm);
518 status_t fill_OIdhw4i4o(memory_desc_t &md) {
519 if (md.ndims != 5) return invalid_arguments;
521 const dims_t block_dims = { 4, 4, 1, 1, 1 };
525 return fill_contiguous_blocked(md, block_dims, perm);
528 status_t fill_OIdhw8i8o(memory_desc_t &md) {
529 if (md.ndims != 5) return invalid_arguments;
531 const dims_t block_dims = {8, 8, 1, 1, 1};
535 return fill_contiguous_blocked(md, block_dims, perm);
538 status_t fill_OIhw4i16o4i(memory_desc_t &md) {
539 if (md.ndims != 4) return invalid_arguments;
541 const dims_t block_dims = {16, 16, 1, 1};
545 return fill_contiguous_blocked(md, block_dims, perm);
548 status_t fill_OhIw8o4i(memory_desc_t &md) {
549 if (md.ndims != 4) return invalid_arguments;
551 const dims_t block_dims = {8, 4, 1, 1};
555 return fill_contiguous_blocked(md, block_dims, perm);
558 status_t fill_OhIw8o32i(memory_desc_t &md) {
559 if (md.ndims != 4) return invalid_arguments;
561 const dims_t block_dims = {8, 32, 1, 1};
565 return fill_contiguous_blocked(md, block_dims, perm);
568 status_t fill_OhIw16o32i(memory_desc_t &md) {
569 if (md.ndims != 4) return invalid_arguments;
571 const dims_t block_dims = {16, 32, 1, 1};
575 return fill_contiguous_blocked(md, block_dims, perm);
578 status_t fill_OIhw8i16o2i(memory_desc_t &md) {
579 if (md.ndims != 4) return invalid_arguments;
581 const dims_t block_dims = {16, 16, 1, 1};
585 return fill_contiguous_blocked(md, block_dims, perm);
588 status_t fill_OIdhw8i16o2i(memory_desc_t &md) {
589 if (md.ndims != 5) return invalid_arguments;
591 const dims_t block_dims = {16, 16, 1, 1, 1};
595 return fill_contiguous_blocked(md, block_dims, perm);
598 status_t fill_OIhw8o8i(memory_desc_t &md) {
599 if (md.ndims != 4) return invalid_arguments;
601 const dims_t block_dims = {8, 8, 1, 1};
605 return fill_contiguous_blocked(md, block_dims, perm);
608 status_t fill_OIhw16o16i(memory_desc_t &md) {
609 if (md.ndims != 4) return invalid_arguments;
611 const dims_t block_dims = {16, 16, 1, 1};
615 return fill_contiguous_blocked(md, block_dims, perm);
618 status_t fill_OIdhw16o16i(memory_desc_t &md) {
619 if (md.ndims != 5) return invalid_arguments;
621 const dims_t block_dims = {16, 16, 1, 1, 1};
625 return fill_contiguous_blocked(md, block_dims, perm);
628 status_t fill_OIdhw8o8i(memory_desc_t &md) {
629 if (md.ndims != 5) return invalid_arguments;
631 const dims_t block_dims = {8, 8, 1, 1, 1};
635 return fill_contiguous_blocked(md, block_dims, perm);
638 status_t fill_IOhw16o16i(memory_desc_t &md) {
639 if (md.ndims != 4) return invalid_arguments;
641 const dims_t block_dims = {16, 16, 1, 1};
645 return fill_contiguous_blocked(md, block_dims, perm);
648 status_t fill_OIhw8o16i2o(memory_desc_t &md) {
649 if (md.ndims != 4) return invalid_arguments;
651 const dims_t block_dims = {16, 16, 1, 1};
655 return fill_contiguous_blocked(md, block_dims, perm);
658 status_t fill_Oihw4o(memory_desc_t &md) {
659 if (md.ndims != 4) return invalid_arguments;
661 const dims_t block_dims = {4, 1, 1, 1};
665 return fill_contiguous_blocked(md, block_dims, perm);
668 status_t fill_Oihw16o(memory_desc_t &md) {
669 if (md.ndims != 4) return invalid_arguments;
671 const dims_t block_dims = { 16, 1, 1, 1 };
675 return fill_contiguous_blocked(md, block_dims, perm);
678 status_t fill_Oidhw4o(memory_desc_t &md) {
679 if (md.ndims != 5) return invalid_arguments;
681 const dims_t block_dims = { 4, 1, 1, 1, 1 };
685 return fill_contiguous_blocked(md, block_dims, perm);
688 status_t fill_Oidhw16o(memory_desc_t &md) {
689 if (md.ndims != 5) return invalid_arguments;
691 const dims_t block_dims = {16, 1, 1, 1, 1};
695 return fill_contiguous_blocked(md, block_dims, perm);
698 status_t fill_Ohwi8o(memory_desc_t &md) {
699 if (md.ndims != 4) return invalid_arguments;
701 const dims_t block_dims = {8, 1, 1, 1};
705 return fill_contiguous_blocked(md, block_dims, perm);
708 status_t fill_Ohwi4o(memory_desc_t &md) {
709 if (md.ndims != 4) return invalid_arguments;
711 const dims_t block_dims = {4, 1, 1, 1};
715 return fill_contiguous_blocked(md, block_dims, perm);
718 status_t fill_Ohwi16o(memory_desc_t &md) {
719 if (md.ndims != 4) return invalid_arguments;
721 const dims_t block_dims = { 16, 1, 1, 1 };
725 return fill_contiguous_blocked(md, block_dims, perm);
728 status_t fill_Odhwi16o(memory_desc_t &md) {
729 if (md.ndims != 5) return invalid_arguments;
731 const dims_t block_dims = {16, 1, 1, 1, 1};
735 return fill_contiguous_blocked(md, block_dims, perm);
738 status_t fill_Odhwi8o(memory_desc_t &md) {
739 if (md.ndims != 5) return invalid_arguments;
741 const dims_t block_dims = {8, 1, 1, 1, 1};
745 return fill_contiguous_blocked(md, block_dims, perm);
/* `goiw`: plain 4D grouped-weights layout in logical order; dim 3 is
 * contiguous (perm runs outermost to innermost). */
status_t fill_goiw(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = {0, 1, 2, 3};
    return fill_nonblocked(md, perm);
755 status_t fill_gOwi4o(memory_desc_t &md) {
756 if (md.ndims != 4) return invalid_arguments;
758 const dims_t block_dims = {1, 4, 1, 1};
762 return fill_contiguous_blocked(md, block_dims, perm);
765 status_t fill_gOwi8o(memory_desc_t &md) {
766 if (md.ndims != 4) return invalid_arguments;
768 const dims_t block_dims = { 1, 8, 1, 1 };
772 return fill_contiguous_blocked(md, block_dims, perm);
775 status_t fill_gOIw8o8i(memory_desc_t &md) {
776 if (md.ndims != 4) return invalid_arguments;
778 const dims_t block_dims = { 1, 8, 8, 1 };
782 return fill_contiguous_blocked(md, block_dims, perm);
785 status_t fill_gOIw4i4o(memory_desc_t &md) {
786 if (md.ndims != 4) return invalid_arguments;
788 const dims_t block_dims = { 1, 4, 4, 1 };
792 return fill_contiguous_blocked(md, block_dims, perm);
795 status_t fill_gOIw8i8o(memory_desc_t &md) {
796 if (md.ndims != 4) return invalid_arguments;
798 const dims_t block_dims = {1, 8, 8, 1};
802 return fill_contiguous_blocked(md, block_dims, perm);
805 status_t fill_gOIw16i16o(memory_desc_t &md) {
806 if (md.ndims != 4) return invalid_arguments;
808 const dims_t block_dims = {1, 16, 16, 1};
812 return fill_contiguous_blocked(md, block_dims, perm);
815 status_t fill_gOIw16o16i(memory_desc_t &md) {
816 if (md.ndims != 4) return invalid_arguments;
818 const dims_t block_dims = {1, 16, 16, 1};
822 return fill_contiguous_blocked(md, block_dims, perm);
825 status_t fill_gOiw4o(memory_desc_t &md) {
826 if (md.ndims != 4) return invalid_arguments;
828 const dims_t block_dims = { 1, 4, 1, 1 };
832 return fill_contiguous_blocked(md, block_dims, perm);
835 status_t fill_gOiw16o(memory_desc_t &md) {
836 if (md.ndims != 4) return invalid_arguments;
838 const dims_t block_dims = {1, 16, 1, 1};
842 return fill_contiguous_blocked(md, block_dims, perm);
845 status_t fill_gOwi16o(memory_desc_t &md) {
846 if (md.ndims != 4) return invalid_arguments;
848 const dims_t block_dims = {1, 16, 1, 1};
852 return fill_contiguous_blocked(md, block_dims, perm);
855 status_t fill_gOIw8i16o2i(memory_desc_t &md) {
856 if (md.ndims != 4) return invalid_arguments;
858 const dims_t block_dims = {1, 16, 16, 1};
862 return fill_contiguous_blocked(md, block_dims, perm);
865 status_t fill_gOIw8o16i2o(memory_desc_t &md) {
866 if (md.ndims != 4) return invalid_arguments;
868 const dims_t block_dims = {1, 16, 16, 1};
872 return fill_contiguous_blocked(md, block_dims, perm);
875 status_t fill_gIOw16o16i(memory_desc_t &md) {
876 if (md.ndims != 4) return invalid_arguments;
878 const dims_t block_dims = {1, 16, 16, 1};
882 return fill_contiguous_blocked(md, block_dims, perm);
/* `goihw`: plain 5D grouped-weights layout in logical order; dim 4 is
 * contiguous (perm runs outermost to innermost). */
status_t fill_goihw(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = {0, 1, 2, 3, 4};
    return fill_nonblocked(md, perm);
/* `hwigo` (also used for hwigo_s8s8 by compute_blocking): plain 5D
 * grouped-weights layout with spatial dims 3, 4 outermost and logical
 * dim 1 innermost (contiguous). */
status_t fill_hwigo(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = {3, 4, 2, 0, 1};
    return fill_nonblocked(md, perm);
/* `giohw`: plain 5D grouped-weights layout with logical dims 1 and 2
 * swapped; dim 4 stays contiguous (perm runs outermost to innermost). */
status_t fill_giohw(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = {0, 2, 1, 3, 4};
    return fill_nonblocked(md, perm);
906 status_t fill_gOIhw4o4i(memory_desc_t &md) {
907 if (md.ndims != 5) return invalid_arguments;
909 const dims_t block_dims = {1, 4, 4, 1, 1};
913 return fill_contiguous_blocked(md, block_dims, perm);
916 status_t fill_gOIhw4i4o(memory_desc_t &md) {
917 if (md.ndims != 5) return invalid_arguments;
919 const dims_t block_dims = { 1, 4, 4, 1, 1 };
923 return fill_contiguous_blocked(md, block_dims, perm);
926 status_t fill_gOIhw8i8o(memory_desc_t &md) {
927 if (md.ndims != 5) return invalid_arguments;
929 const dims_t block_dims = { 1, 8, 8, 1, 1 };
933 return fill_contiguous_blocked(md, block_dims, perm);
936 status_t fill_gOIhw16i16o(memory_desc_t &md) {
937 if (md.ndims != 5) return invalid_arguments;
939 const dims_t block_dims = {1, 16, 16, 1, 1};
943 return fill_contiguous_blocked(md, block_dims, perm);
946 status_t fill_gOIdhw16i16o(memory_desc_t &md) {
947 if (md.ndims != 6) return invalid_arguments;
949 const dims_t block_dims = {1, 16, 16, 1, 1, 1};
953 return fill_contiguous_blocked(md, block_dims, perm);
956 status_t fill_gOIdhw4i4o(memory_desc_t &md) {
957 if (md.ndims != 6) return invalid_arguments;
959 const dims_t block_dims = {1, 4, 4, 1, 1, 1};
963 return fill_contiguous_blocked(md, block_dims, perm);
966 status_t fill_gOIdhw8i8o(memory_desc_t &md) {
967 if (md.ndims != 6) return invalid_arguments;
969 const dims_t block_dims = { 1, 8, 8, 1, 1, 1 };
972 6, 8, 7, 9, 10, 11 };
973 return fill_contiguous_blocked(md, block_dims, perm);
976 status_t fill_gOihw4o(memory_desc_t &md) {
977 if (md.ndims != 5) return invalid_arguments;
979 const dims_t block_dims = {1, 4, 1, 1, 1};
983 return fill_contiguous_blocked(md, block_dims, perm);
986 status_t fill_gOihw16o(memory_desc_t &md) {
987 if (md.ndims != 5) return invalid_arguments;
989 const dims_t block_dims = { 1, 16, 1, 1, 1 };
993 return fill_contiguous_blocked(md, block_dims, perm);
996 status_t fill_gOidhw4o(memory_desc_t &md) {
997 if (md.ndims != 6) return invalid_arguments;
999 const dims_t block_dims = {1, 4, 1, 1, 1, 1};
1000 const int perm[] = {
1002 6, 7, 8, 9, 10, 11};
1003 return fill_contiguous_blocked(md, block_dims, perm);
1006 status_t fill_gOidhw16o(memory_desc_t &md) {
1007 if (md.ndims != 6) return invalid_arguments;
1009 const dims_t block_dims = { 1, 16, 1, 1, 1, 1 };
1010 const int perm[] = {
1012 6, 7, 8, 9, 10, 11 };
1013 return fill_contiguous_blocked(md, block_dims, perm);
1016 status_t fill_gOhwi8o(memory_desc_t &md) {
1017 if (md.ndims != 5) return invalid_arguments;
1019 const dims_t block_dims = {1, 8, 1, 1, 1};
1020 const int perm[] = {
1023 return fill_contiguous_blocked(md, block_dims, perm);
1026 status_t fill_gOhwi4o(memory_desc_t &md) {
1027 if (md.ndims != 5) return invalid_arguments;
1029 const dims_t block_dims = {1, 4, 1, 1, 1};
1030 const int perm[] = {
1033 return fill_contiguous_blocked(md, block_dims, perm);
1036 status_t fill_gOhwi16o(memory_desc_t &md) {
1037 if (md.ndims != 5) return invalid_arguments;
1039 const dims_t block_dims = { 1, 16, 1, 1, 1 };
1040 const int perm[] = {
1043 return fill_contiguous_blocked(md, block_dims, perm);
1046 status_t fill_gOdhwi16o(memory_desc_t &md) {
1047 if (md.ndims != 6) return invalid_arguments;
1049 const dims_t block_dims = {1, 16, 1, 1, 1, 1};
1050 const int perm[] = {
1052 6, 7, 8, 9, 10, 11};
1053 return fill_contiguous_blocked(md, block_dims, perm);
1056 status_t fill_gOdhwi8o(memory_desc_t &md) {
1057 if (md.ndims != 6) return invalid_arguments;
1059 const dims_t block_dims = {1, 8, 1, 1, 1, 1};
1060 const int perm[] = {
1062 6, 7, 8, 9, 10, 11};
1063 return fill_contiguous_blocked(md, block_dims, perm);
1066 status_t fill_gOIhw4i16o4i(memory_desc_t &md) {
1067 if (md.ndims != 5) return invalid_arguments;
1069 const dims_t block_dims = {1, 16, 16, 1, 1};
1070 const int perm[] = {
1073 return fill_contiguous_blocked(md, block_dims, perm);
1076 status_t fill_gOIhw2i8o4i(memory_desc_t &md) {
1077 if (md.ndims != 5) return invalid_arguments;
1079 const dims_t block_dims = {1, 8, 8, 1, 1};
1080 const int perm[] = {
1083 return fill_contiguous_blocked(md, block_dims, perm);
1086 status_t fill_gOhIw8o4i(memory_desc_t &md) {
1087 if (md.ndims != 5) return invalid_arguments;
1089 const dims_t block_dims = {1, 8, 4, 1, 1};
1090 const int perm[] = {
1093 return fill_contiguous_blocked(md, block_dims, perm);
1096 status_t fill_Goihw8g(memory_desc_t &md) {
1097 if (md.ndims != 5) return invalid_arguments;
1099 const dims_t block_dims = {8, 1, 1, 1, 1};
1100 const int perm[] = {
1103 return fill_contiguous_blocked(md, block_dims, perm);
1106 status_t fill_Goihw16g(memory_desc_t &md) {
1107 if (md.ndims != 5) return invalid_arguments;
1109 const dims_t block_dims = {16, 1, 1, 1, 1};
1110 const int perm[] = {
1113 return fill_contiguous_blocked(md, block_dims, perm);
1116 status_t fill_gOIhw8i16o2i(memory_desc_t &md) {
1117 if (md.ndims != 5) return invalid_arguments;
1119 const dims_t block_dims = {1, 16, 16, 1, 1};
1120 const int perm[] = {
1123 return fill_contiguous_blocked(md, block_dims, perm);
1126 status_t fill_gOIdhw8i16o2i(memory_desc_t &md) {
1127 if (md.ndims != 6) return invalid_arguments;
1129 const dims_t block_dims = {1, 16, 16, 1, 1, 1};
1130 const int perm[] = {
1132 6, 8, 7, 9, 10, 11};
1133 return fill_contiguous_blocked(md, block_dims, perm);
1136 status_t fill_gOIhw8o8i(memory_desc_t &md) {
1137 if (md.ndims != 5) return invalid_arguments;
1139 const dims_t block_dims = {1, 8, 8, 1, 1};
1140 const int perm[] = {
1143 return fill_contiguous_blocked(md, block_dims, perm);
1146 status_t fill_gOIhw16o16i(memory_desc_t &md) {
1147 if (md.ndims != 5) return invalid_arguments;
1149 const dims_t block_dims = {1, 16, 16, 1, 1};
1150 const int perm[] = {
1153 return fill_contiguous_blocked(md, block_dims, perm);
1156 status_t fill_gOIdhw16o16i(memory_desc_t &md) {
1157 if (md.ndims != 6) return invalid_arguments;
1159 const dims_t block_dims = {1, 16, 16, 1, 1, 1};
1160 const int perm[] = {
1162 6, 7, 8, 9, 10, 11};
1163 return fill_contiguous_blocked(md, block_dims, perm);
1166 status_t fill_gOIdhw8o8i(memory_desc_t &md) {
1167 if (md.ndims != 6) return invalid_arguments;
1169 const dims_t block_dims = {1, 8, 8, 1, 1, 1};
1170 const int perm[] = {
1172 6, 7, 8, 9, 10, 11};
1173 return fill_contiguous_blocked(md, block_dims, perm);
1176 status_t fill_gIOhw16o16i(memory_desc_t &md) {
1177 if (md.ndims != 5) return invalid_arguments;
1179 const dims_t block_dims = {1, 16, 16, 1, 1};
1180 const int perm[] = {
1183 return fill_contiguous_blocked(md, block_dims, perm);
1186 status_t fill_gOIhw8o16i2o(memory_desc_t &md) {
1187 if (md.ndims != 5) return invalid_arguments;
1189 const dims_t block_dims = {1, 16, 16, 1, 1};
1190 const int perm[] = {
1193 return fill_contiguous_blocked(md, block_dims, perm);
/* `ntc`: plain 3D layout with logical dims 0 and 1 swapped in memory;
 * dim 2 stays contiguous (perm runs outermost to innermost). */
status_t fill_ntc(memory_desc_t &md) {
    if (md.ndims != 3) return invalid_arguments;

    const int perm[3] = { 1, 0, 2 };
    return fill_nonblocked(md, perm);
/* `tnc`: plain 3D layout in logical order; dim 2 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_tnc(memory_desc_t &md) {
    if (md.ndims != 3) return invalid_arguments;
    const int perm[3] = { 0, 1, 2 };
    return fill_nonblocked(md, perm);
/* `ldsnc`: plain 5D layout in logical order; dim 4 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_ldsnc(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;
    const int perm[5] = { 0, 1, 2, 3, 4 };
    return fill_nonblocked(md, perm);
/* `ldigo`: plain 5D layout in logical order; dim 4 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_ldigo(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = { 0, 1, 2, 3, 4 };
    return fill_nonblocked(md, perm);
/* `ldgoi`: plain 5D layout with logical dim 2 moved innermost
 * (contiguous); perm runs outermost to innermost. */
status_t fill_ldgoi(memory_desc_t &md) {
    if (md.ndims != 5) return invalid_arguments;

    const int perm[5] = { 0, 1, 3, 4, 2 };
    return fill_nonblocked(md, perm);
/* `ldgo`: plain 4D layout in logical order; dim 3 is contiguous
 * (perm runs outermost to innermost). */
status_t fill_ldgo(memory_desc_t &md) {
    if (md.ndims != 4) return invalid_arguments;

    const int perm[4] = { 0, 1, 2, 3 };
    return fill_nonblocked(md, perm);
1238 status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc)
1240 if (memory_desc.ndims == 0) return invalid_arguments;
1242 switch (memory_desc.format) {
1243 case x: return fill_x(memory_desc);
1244 case nc: return fill_nc(memory_desc);
1245 case ncw: return fill_ncw(memory_desc);
1246 case nwc: return fill_nwc(memory_desc);
1247 case nCw4c: return fill_nCw4c(memory_desc);
1248 case nCw8c: return fill_nCw8c(memory_desc);
1249 case nCw16c: return fill_nCw16c(memory_desc);
1250 case nchw: return fill_nchw(memory_desc);
1251 case nhwc: return fill_nhwc(memory_desc);
1252 case chwn: return fill_chwn(memory_desc);
1253 case nChw4c: return fill_nChw4c(memory_desc);
1254 case nChw8c: case oIhw8i: return fill_nChw8c(memory_desc);
1255 case nChw16c: case oIhw16i: return fill_nChw16c(memory_desc);
1256 case oi: return fill_oi(memory_desc);
1257 case io: return fill_io(memory_desc);
1258 case oiw: return fill_oiw(memory_desc);
1259 case wio: return fill_wio(memory_desc);
1260 case Owi4o: return fill_Owi4o(memory_desc);
1261 case OIw4i4o: return fill_OIw4i4o(memory_desc);
1262 case Owi8o: return fill_Owi8o(memory_desc);
1263 case OIw8o8i: return fill_OIw8o8i(memory_desc);
1264 case OIw8i8o: return fill_OIw8i8o(memory_desc);
1265 case OIw16i16o: return fill_OIw16i16o(memory_desc);
1266 case OIw16o16i: return fill_OIw16o16i(memory_desc);
1267 case Oiw4o: return fill_Oiw4o(memory_desc);
1268 case Oiw16o: return fill_Oiw16o(memory_desc);
1269 case Owi16o: return fill_Owi16o(memory_desc);
1270 case OIw8i16o2i: return fill_OIw8i16o2i(memory_desc);
1271 case OIw8o16i2o: return fill_OIw8o16i2o(memory_desc);
1272 case IOw16o16i: return fill_IOw16o16i(memory_desc);
1273 case oihw: return fill_oihw(memory_desc);
1274 case ihwo: return fill_ihwo(memory_desc);
1275 case hwio: return fill_hwio(memory_desc);
1276 case iohw: return fill_iohw(memory_desc);
1277 case hwio_s8s8: return fill_hwio(memory_desc);
1278 case dhwio: return fill_dhwio(memory_desc);
1279 case OIhw4i4o: return fill_OIhw4i4o(memory_desc);
1280 case OIhw8i8o: return fill_OIhw8i8o(memory_desc);
1281 case OIhw16i16o: return fill_OIhw16i16o(memory_desc);
1282 case OIhw4i16o4i: return fill_OIhw4i16o4i(memory_desc);
1283 case OhIw8o4i: return fill_OhIw8o4i(memory_desc);
1284 case OhIw8o32i: return fill_OhIw8o32i(memory_desc);
1285 case OhIw16o32i: return fill_OhIw16o32i(memory_desc);
1286 case OhIw8o4i_s8s8: return fill_OhIw8o4i(memory_desc);
1287 case OIhw4i16o4i_s8s8: return fill_OIhw4i16o4i(memory_desc);
1288 case OIhw8i16o2i: return fill_OIhw8i16o2i(memory_desc);
1289 case OIdhw8i16o2i: return fill_OIdhw8i16o2i(memory_desc);
1290 case OIhw8o16i2o: return fill_OIhw8o16i2o(memory_desc);
1291 case OIhw8o8i: return fill_OIhw8o8i(memory_desc);
1292 case OIhw16o16i: return fill_OIhw16o16i(memory_desc);
1293 case IOhw16o16i: return fill_IOhw16o16i(memory_desc);
1294 case Oihw4o: return fill_Oihw4o(memory_desc);
1295 case Oihw16o: return fill_Oihw16o(memory_desc);
1296 case Ohwi8o: return fill_Ohwi8o(memory_desc);
1297 case Ohwi4o: return fill_Ohwi4o(memory_desc);
1298 case Ohwi16o: return fill_Ohwi16o(memory_desc);
1299 case goiw: return fill_goiw(memory_desc);
1300 case gOwi4o: return fill_gOwi4o(memory_desc);
1301 case gOIw4i4o: return fill_gOIw4i4o(memory_desc);
1302 case gOwi8o: return fill_gOwi8o(memory_desc);
1303 case gOIw8o8i: return fill_gOIw8o8i(memory_desc);
1304 case gOIw8i8o: return fill_gOIw8i8o(memory_desc);
1305 case gOIw16i16o: return fill_gOIw16i16o(memory_desc);
1306 case gOIw16o16i: return fill_gOIw16o16i(memory_desc);
1307 case gOiw4o: return fill_gOiw4o(memory_desc);
1308 case gOiw16o: return fill_gOiw16o(memory_desc);
1309 case gOwi16o: return fill_gOwi16o(memory_desc);
1310 case gOIw8i16o2i: return fill_gOIw8i16o2i(memory_desc);
1311 case gOIw8o16i2o: return fill_gOIw8o16i2o(memory_desc);
1312 case gIOw16o16i: return fill_gIOw16o16i(memory_desc);
1313 case goihw: return fill_goihw(memory_desc);
1314 case hwigo: return fill_hwigo(memory_desc);
1315 case giohw: return fill_giohw(memory_desc);
1316 case hwigo_s8s8: return fill_hwigo(memory_desc);
1317 case gOIhw4i4o: return fill_gOIhw4i4o(memory_desc);
1318 case gOIhw8i8o: return fill_gOIhw8i8o(memory_desc);
1319 case gOIhw16i16o: return fill_gOIhw16i16o(memory_desc);
1320 case gOIhw4i16o4i: return fill_gOIhw4i16o4i(memory_desc);
1321 case gOhIw8o4i: return fill_gOhIw8o4i(memory_desc);
1322 case gOhIw8o4i_s8s8: return fill_gOhIw8o4i(memory_desc);
1323 case gOIhw4i16o4i_s8s8: return fill_gOIhw4i16o4i(memory_desc);
1324 case gOIhw2i8o4i: return fill_gOIhw2i8o4i(memory_desc);
1325 case gOIhw2i8o4i_s8s8: return fill_gOIhw2i8o4i(memory_desc);
1326 case gOIhw8i16o2i: return fill_gOIhw8i16o2i(memory_desc);
1327 case gOIdhw8i16o2i: return fill_gOIdhw8i16o2i(memory_desc);
1328 case gOIhw8o16i2o: return fill_gOIhw8o16i2o(memory_desc);
1329 case gOIhw4o4i: return fill_gOIhw4o4i(memory_desc);
1330 case gOIhw4o4i_s8s8: return fill_gOIhw4o4i(memory_desc);
1331 case gOIhw8o8i: return fill_gOIhw8o8i(memory_desc);
1332 case gOIhw16o16i: return fill_gOIhw16o16i(memory_desc);
1333 case gIOhw16o16i: return fill_gIOhw16o16i(memory_desc);
1334 case gOihw4o: return fill_gOihw4o(memory_desc);
1335 case gOihw16o: return fill_gOihw16o(memory_desc);
1336 case gOhwi8o: return fill_gOhwi8o(memory_desc);
1337 case gOhwi4o: return fill_gOhwi4o(memory_desc);
1338 case gOhwi16o: return fill_gOhwi16o(memory_desc);
1339 case Goihw8g: return fill_Goihw8g(memory_desc);
1340 case Goihw16g: return fill_Goihw16g(memory_desc);
1341 case Goihw16g_s8s8: return fill_Goihw16g(memory_desc);
1342 case ncdhw: return fill_ncdhw(memory_desc);
1343 case ndhwc: return fill_ndhwc(memory_desc);
1344 case oidhw: return fill_oidhw(memory_desc);
1345 case goidhw: return fill_goidhw(memory_desc);
1346 case nCdhw4c: return fill_nCdhw4c(memory_desc);
1347 case nCdhw8c: case oIdhw8i: return fill_nCdhw8c(memory_desc);
1348 case nCdhw16c: case oIdhw16i: return fill_nCdhw16c(memory_desc);
1349 case OIdhw16i16o: return fill_OIdhw16i16o(memory_desc);
1350 case gOIdhw16i16o: return fill_gOIdhw16i16o(memory_desc);
1351 case OIdhw4i4o: return fill_OIdhw4i4o(memory_desc);
1352 case gOIdhw4i4o: return fill_gOIdhw4i4o(memory_desc);
1353 case OIdhw8i8o: return fill_OIdhw8i8o(memory_desc);
1354 case gOIdhw8i8o: return fill_gOIdhw8i8o(memory_desc);
1355 case OIdhw16o16i: return fill_OIdhw16o16i(memory_desc);
1356 case gOIdhw16o16i: return fill_gOIdhw16o16i(memory_desc);
1357 case OIdhw8o8i: return fill_OIdhw8o8i(memory_desc);
1358 case gOIdhw8o8i: return fill_gOIdhw8o8i(memory_desc);
1359 case Oidhw4o: return fill_Oidhw4o(memory_desc);
1360 case Oidhw16o: return fill_Oidhw16o(memory_desc);
1361 case Odhwi16o: return fill_Odhwi16o(memory_desc);
1362 case Odhwi8o: return fill_Odhwi8o(memory_desc);
1363 case gOidhw4o: return fill_gOidhw4o(memory_desc);
1364 case gOidhw16o: return fill_gOidhw16o(memory_desc);
1365 case gOdhwi16o: return fill_gOdhwi16o(memory_desc);
1366 case gOdhwi8o: return fill_gOdhwi8o(memory_desc);
1367 case ntc: return fill_ntc(memory_desc);
1368 case tnc: return fill_tnc(memory_desc);
1369 case ldsnc: return fill_ldsnc(memory_desc);
1370 case ldigo: return fill_ldigo(memory_desc);
1371 case ldgoi: return fill_ldgoi(memory_desc);
1372 case ldgo: return fill_ldgo(memory_desc);
1374 case rnn_packed: return success;
1378 return invalid_arguments;
1384 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s