Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_memory.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <limits>
6 #include <vector>
7 #include <cmath>
8 #include <algorithm>
9 #include <unordered_set>
10 #include <utility>
11
12 #include <mkldnn_types.h>
13 #include "mkldnn_memory.h"
14 #include "mkldnn_node.h"
15 #include "mkldnn_extension_utils.h"
16
17 using namespace InferenceEngine;
18 using namespace mkldnn;
19
20 namespace MKLDNNPlugin {
21
22 MKLDNNMemory::MKLDNNMemory(const engine& eng) : eng(eng) {}
23
24 size_t MKLDNNMemory::GetSize() const {
25     uint8_t itemSize = MKLDNNExtensionUtils::sizeOfDataType(mkldnn::memory::data_type(GetDataType()));
26
27     auto desc = GetDescriptor();
28     std::vector<int> dims(desc.data.layout_desc.blocking.padding_dims,
29                           desc.data.layout_desc.blocking.padding_dims + desc.data.ndims);
30     return std::accumulate(std::begin(dims), std::end(dims), (size_t) 1, std::multiplies<size_t>()) * itemSize;
31 }
32
33 void MKLDNNMemory::Create(memory::dims dims, memory::data_type data_type, memory::format format, const void* data) {
34     if (!isConsistant(dims, format)) {
35         THROW_IE_EXCEPTION << "dims and format are inconsistent.";
36     }
37
38     if (format == memory::blocked) {
39         format = memory::any;
40     }
41
42     memory::desc desc = MKLDNNMemoryDesc({dims}, data_type, format);
43
44     if (format == memory::any) {
45         CreateBlockingDesc(desc);
46     }
47
48     Create(desc, data);
49 }
50
51 void MKLDNNMemory::Create(const mkldnn::memory::desc& desc, const void *data) {
52     auto primitive_desc = memory::primitive_desc(desc, eng);
53     uint8_t itemSize = MKLDNNExtensionUtils::sizeOfDataType(mkldnn::memory::data_type(desc.data.data_type));
54
55     if (data == nullptr) {
56         prim.reset(new memory(primitive_desc));
57
58         size_t real_size = 0;
59         if (desc.data.format == mkldnn_wino_fmt)
60             return;
61         if (prim->get_primitive_desc().desc().data.ndims > 0) {
62             real_size = static_cast<size_t>(prim->get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[0]);
63             for (int i = 1; i < prim->get_primitive_desc().desc().data.ndims; i++) {
64                 real_size *= prim->get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[i];
65             }
66         }
67         uint8_t* dataPtr = static_cast<uint8_t*>(GetData());
68         dataPtr += itemSize * prim->get_primitive_desc().desc().data.layout_desc.blocking.offset_padding;
69
70         memset(dataPtr, 0, real_size * itemSize);
71     } else {
72         // MKLDNN accepts not a const data, probably need to remove some level of consteness in a call stack
73         prim.reset(new memory(primitive_desc, const_cast<void*>(data)));
74     }
75 }
76
77 void MKLDNNMemory::SetData(memory::data_type dataType, memory::format format, const void* data, size_t size, bool ftz) const {
78     uint8_t itemSize = MKLDNNExtensionUtils::sizeOfDataType(mkldnn::memory::data_type(dataType));
79
80     if (static_cast<mkldnn_memory_format_t>(format) != GetDescriptor().data.format ||
81             GetDataType() != dataType) {
82         auto memData = GetDescriptor().data;
83
84         std::vector<ptrdiff_t> dims(memData.dims, memData.dims + memData.ndims);
85
86         auto dataType = GetDataType();
87
88         MKLDNNMemory src(eng);
89         src.Create(dims, dataType, format, data);
90
91         std::shared_ptr<mkldnn::reorder> pReorder =
92                 std::shared_ptr<mkldnn::reorder>(new mkldnn::reorder(src.GetPrimitive(), GetPrimitive()));
93
94         mkldnn::stream(stream::kind::eager).submit({*pReorder});
95     } else {
96         uint8_t* dataPtr = static_cast<uint8_t*>(GetData());
97         // We cannot support strides for i/o blobs because it affects performance.
98         dataPtr += itemSize * prim->get_primitive_desc().desc().data.layout_desc.blocking.offset_padding;
99         memcpy(dataPtr, data, size);
100     }
101
102     if (ftz && dataType == mkldnn_f32) {
103         // Internal blobs haven't strides yet.
104         auto *memData = static_cast<float *>(GetData());
105         memData += prim->get_primitive_desc().desc().data.layout_desc.blocking.offset_padding;
106         size_t realSize = GetSize() / sizeof(float);
107         for (size_t i = 0; i < realSize; i++) {
108             if (memData[i] != 0 && (fabsf(memData[i]) < std::numeric_limits<float>::min())) {
109                 memData[i] = 0.0f;
110             }
111         }
112     }
113 }
114
115 void MKLDNNMemory::SetData(const MKLDNNMemory& memory, bool ftz) const {
116     mkldnn::reorder reorderPrim(memory.GetPrimitive(), GetPrimitive());
117     mkldnn::stream(stream::kind::eager).submit({reorderPrim});
118
119     if (ftz && memory.GetDataType() == mkldnn::memory::f32 && GetFormat() != mkldnn::memory::wino_fmt) {
120         // Internal blobs haven't strides yet.
121         auto *memData = static_cast<float *>(GetData());
122         memData += prim->get_primitive_desc().desc().data.layout_desc.blocking.offset_padding;
123         size_t realSize = GetSize() / sizeof(float);
124         for (size_t i = 0; i < realSize; i++) {
125             if (memData[i] != 0 && (fabsf(memData[i]) < std::numeric_limits<float>::min())) {
126                 memData[i] = 0.0f;
127             }
128         }
129     }
130 }
131
132 void MKLDNNMemory::FillZero() {
133     void* dataPtr = GetData();
134     memset(dataPtr, 0, GetSize());
135 }
136
137 bool MKLDNNMemory::isConsistant(memory::dims dims, memory::format format) {
138     using f = mkldnn::memory::format;
139
140     size_t ndims = 0;
141
142     switch (format) {
143         case f::x:
144             ndims = 1; break;
145         case f::nc:
146         case f::oi:
147         case f::io:
148             ndims = 2; break;
149         case f::ntc:
150         case f::tnc:
151             ndims = 3; break;
152         case f::nchw:
153         case f::nhwc:
154         case f::chwn:
155         case f::nChw8c:
156         case f::nChw16c:
157         case f::oihw:
158         case f::ihwo:
159         case f::hwio:
160         case f::OIhw8i8o:
161         case f::OIhw16i16o:
162         case f::OIhw8o8i:
163         case f::OIhw16o16i:
164         case f::OIhw8i16o2i:
165         case f::OIhw8o16i2o:
166         case f::Ohwi8o:
167         case f::Ohwi16o:
168         case f::OhIw16o4i:
169         case f::OIhw4i16o4i:
170             ndims = 4; break;
171         // DHW
172         case f::ncdhw:
173         case f::ndhwc:
174         case f::nCdhw8c:
175         case f::nCdhw16c:
176         case f::oidhw:
177         case f::OIdhw8i8o:
178         case f::OIdhw16i16o:
179         case f::OIdhw8o8i:
180         case f::OIdhw16o16i:
181         case f::OIdhw8i16o2i:
182         case f::Odhwi8o:
183         case f::Odhwi16o:
184         // Group HW
185         case f::hwigo:
186         case f::goihw:
187         case f::gOIhw8i8o:
188         case f::gOIhw16i16o:
189         case f::gOIhw8i16o2i:
190         case f::gOIhw8o16i2o:
191         case f::gOhwi8o:
192         case f::gOhwi16o:
193         case f::gOIhw8o8i:
194         case f::gOIhw16o16i:
195         case f::gOhIw16o4i:
196         case f::Goihw8g:
197         case f::Goihw16g:
198             ndims = 5; break;
199         case f::goidhw:
200         case f::gOIdhw8i8o:
201         case f::gOIdhw16i16o:
202         case f::gOIdhw8i16o2i:
203         case f::gOdhwi8o:
204         case f::gOdhwi16o:
205         case f::gOIdhw8o8i:
206         case f::gOIdhw16o16i:
207             ndims = 6; break;
208         case f::format_undef:
209             ndims = 0; break;
210         case f::any:
211         case f::wino_fmt:
212         case f::blocked:
213             return true;
214         default:
215             return false;
216     }
217
218     return (dims.size() == ndims);
219 }
220
221 bool MKLDNNMemory::IsPlainFormat(memory::format format) {
222     std::vector<memory::format> plains = {memory::nc, memory::nchw, memory::ncdhw, memory::nhwc, memory::ndhwc, memory::chwn,
223         memory::oi, memory::io, memory::oihw, memory::oidhw, memory::ihwo, memory::tnc,
224         memory::goihw,
225         memory::blocked};
226
227     for (auto it : plains) {
228         if (format == it) {
229             return true;
230         }
231     }
232
233     return false;
234 }
235
236 memory::format MKLDNNMemory::GetPlainFormat(memory::dims dims) {
237     switch (dims.size()) {
238         case 1:
239             return memory::x;
240         case 2:
241             return memory::nc;
242         case 3:
243             return memory::tnc;
244         case 4:
245             return memory::nchw;
246         case 5:
247             return memory::ncdhw;
248         default:
249             return memory::blocked;
250     }
251 }
252
253 InferenceEngine::Layout MKLDNNMemory::GetPlainLayout(memory::dims dims) {
254     switch (dims.size()) {
255         case 0: return Layout::SCALAR;
256         case 1: return Layout::C;
257         case 2: return Layout::NC;
258         case 3: return Layout::CHW;
259         case 4: return Layout::NCHW;
260         default:
261             return Layout::BLOCKED;
262     }
263 }
264
265 void MKLDNNMemory::CreateBlockingDesc(memory::desc &desc) {
266     auto dims = desc.data.dims;
267     int ndims = desc.data.ndims;
268
269     desc.data.format = mkldnn_blocked;
270
271     auto& blk = desc.data.layout_desc.blocking;
272
273     blk.offset_padding = 0;
274
275     for (int i = 0; i < ndims; i++) {
276         blk.block_dims[i] = 1;
277         blk.strides[1][i] = 1;
278         blk.padding_dims[i] = dims[i];
279         blk.offset_padding_to_data[i] = 0;
280     }
281
282     int perm[TENSOR_MAX_DIMS] = {0};
283
284     for (int i = 0; i < ndims; ++i) {
285         perm[i] = i;
286     }
287
288     blk.strides[0][perm[ndims - 1]] = 1;
289
290     for (int d = 1; d < ndims; ++d) {
291         const int prev_idx = perm[ndims - d];
292         const int curr_idx = perm[ndims - 1 - d];
293
294         blk.strides[0][curr_idx] = dims[curr_idx] == 0 ? 1 : blk.strides[0][prev_idx] * (std::max)((ptrdiff_t)1, dims[prev_idx]);
295     }
296 }
297 memory::format MKLDNNMemory::Convert(const InferenceEngine::Layout layout) {
298     switch (layout) {
299         case NCHW:
300             return memory::nchw;
301         case NHWC:
302             return memory::nhwc;
303         case NCDHW:
304             return memory::ncdhw;
305         case NDHWC:
306             return memory::ndhwc;
307         case CHW:
308             return memory::tnc;
309         case NC:
310             return memory::nc;
311         case C:
312             return memory::x;
313         default:
314             return memory::blocked;
315     }
316 }
317
318 std::string MKLDNNMemory::formatToString(memory::format fmt) {
319     switch (fmt) {
320         case memory::format_undef: return "undef";
321         case memory::any: return "any";
322         case memory::blocked: return "blocked";
323
324         case memory::x: return "x";
325
326         case memory::nc: return "nc";
327         case memory::oi: return "oi";
328         case memory::io: return "io";
329
330         case memory::ntc: return "ntc";
331         case memory::tnc: return "tnc";
332
333         case memory::nchw: return "nchw";
334         case memory::nhwc: return "nhwc";
335         case memory::chwn: return "chwn";
336         case memory::nChw8c: return "nChw8c";
337         case memory::nChw16c: return "nChw16c";
338
339         case memory::ncdhw: return "ncdhw";
340         case memory::ndhwc: return "ndhwc";
341         case memory::nCdhw8c: return "nCdhw8c";
342         case memory::nCdhw16c: return "nCdhw16c";
343
344         case memory::oihw: return "oihw";
345         case memory::ihwo: return "ihwo";
346         case memory::OIhw8i8o: return "OIhw8i8o";
347         case memory::OIhw16i16o: return "OIhw16i16o";
348         case memory::OIhw8o8i: return "OIhw8o8i";
349         case memory::OIhw16o16i: return "OIhw16o16i";
350         case memory::OIhw8i16o2i: return "OIhw8i16o2i";
351         case memory::OIhw8o16i2o: return "OIhw8o16i2o";
352         case memory::Ohwi8o: return "Ohwi8o";
353         case memory::Ohwi16o: return "Ohwi16o";
354         case memory::OhIw16o4i: return "OhIw16o4i";
355
356         case memory::oidhw: return "oidhw";
357         case memory::OIdhw8i8o: return "OIdhw8i8o";
358         case memory::OIdhw16i16o: return "OIdhw16i16o";
359         case memory::OIdhw8o8i: return "OIdhw8o8i";
360         case memory::OIdhw16o16i: return "OIdhw16o16i";
361         case memory::OIdhw8i16o2i: return "OIdhw8i16o2i";
362         case memory::Odhwi8o: return "Odhwi8o";
363         case memory::Odhwi16o: return "Odhwi16o";
364
365         case memory::goihw: return "goihw";
366         case memory::hwigo: return "hwigo";
367         case memory::hwio: return "hwio";
368         case memory::gOIhw8i8o: return "gOIhw8i8o";
369         case memory::gOIhw16i16o: return "gOIhw16i16o";
370         case memory::gOIhw8i16o2i: return "gOIhw8i16o2i";
371         case memory::gOIhw8o16i2o: return "gOIhw8o16i2o";
372         case memory::gOhwi8o: return "gOhwi8o";
373         case memory::gOhwi16o: return "gOhwi16o";
374         case memory::gOIhw8o8i: return "gOIhw8o8i";
375         case memory::gOIhw16o16i: return "gOIhw16o16i";
376         case memory::gOhIw16o4i: return "gOhIw16o4i";
377
378         case memory::goidhw: return "goidhw";
379         case memory::gOIdhw8i8o: return "gOIdhw8i8o";
380         case memory::gOIdhw16i16o: return "gOIdhw16i16o";
381         case memory::gOIdhw8i16o2i: return "gOIdhw8i16o2i";
382         case memory::gOdhwi8o: return "gOdhwi8o";
383         case memory::gOdhwi16o: return "gOdhwi16o";
384         case memory::gOIdhw8o8i: return "gOIdhw8o8i";
385         case memory::gOIdhw16o16i: return "gOIdhw16o16i";
386
387         default: {
388             THROW_IE_EXCEPTION << "Unknown data format.";
389         }
390     }
391 }
392
393 bool MKLDNNMemoryDesc::operator==(const MKLDNNMemoryDesc &rhs) const {
394     auto dims_equal = [] (mkldnn_memory_desc_t ldata, mkldnn_memory_desc_t rdata) {
395         if (ldata.ndims != rdata.ndims)
396             return false;
397         for (int i = 0; i < ldata.ndims; i++) {
398             if (ldata.dims[i] != rdata.dims[i])
399                 return false;
400         }
401         return true;
402     };
403     auto blocking_equal = [] (mkldnn_memory_desc_t ldata, mkldnn_memory_desc_t rdata) {
404         if (ldata.ndims != rdata.ndims)
405             return false;
406         mkldnn_blocking_desc_t lblock = ldata.layout_desc.blocking;
407         mkldnn_blocking_desc_t rblock = rdata.layout_desc.blocking;
408         if (lblock.offset_padding != rblock.offset_padding)
409             return false;
410         for (int i = 0; i < ldata.ndims; i++) {
411             if (lblock.block_dims[i] != rblock.block_dims[i] ||
412                 lblock.offset_padding_to_data[i] != rblock.offset_padding_to_data[i] ||
413                 lblock.padding_dims[i] != rblock.padding_dims[i] || lblock.strides[0][i] != rblock.strides[0][i] ||
414                 lblock.strides[1][i] != rblock.strides[1][i])
415                 return false;
416         }
417         return true;
418     };
419     return dims_equal(this->desc.data, rhs.desc.data) &&
420            this->desc.data.data_type == rhs.desc.data.data_type &&
421            this->desc.data.format == rhs.desc.data.format &&
422            this->desc.data.primitive_kind == rhs.desc.data.primitive_kind &&
423            blocking_equal(this->desc.data, rhs.desc.data);
424 }
425
426 bool MKLDNNMemoryDesc::operator!=(const MKLDNNMemoryDesc &rhs) const {
427     return !(*this == rhs);
428 }
429
430 MKLDNNMemoryDesc::operator mkldnn::memory::desc() const {
431     return desc;
432 }
433
434 MKLDNNMemoryDesc::MKLDNNMemoryDesc(mkldnn::memory::dims dims, mkldnn::memory::data_type dataType,
435                                    mkldnn::memory::format format): desc(dims, dataType, mkldnn::memory::any) {
436     if (format != memory::blocked) {
437         desc = mkldnn::memory::desc(dims, dataType, format);
438         return;
439     }
440     MKLDNNMemory::CreateBlockingDesc(desc);
441 }
442
443 MKLDNNMemoryDesc::operator InferenceEngine::TensorDesc() const {
444     Precision precision;
445     switch (desc.data.data_type) {
446         case mkldnn_f32:
447             precision = Precision::FP32;
448             break;
449         case mkldnn_u8:
450             precision = Precision::U8;
451             break;
452         case mkldnn_s8:
453             precision = Precision::I8;
454             break;
455         case mkldnn_s16:
456             precision = Precision::I16;
457             break;
458         case mkldnn_s32:
459             precision = Precision::I32;
460             break;
461         case mkldnn_bin:
462             precision = Precision::BIN;
463             break;
464         default:
465             THROW_IE_EXCEPTION << "Cannot cast to TensorDesc. Unsupported precision!";
466     }
467     Layout layout;
468     SizeVector order;
469     SizeVector blkDims;
470     auto blkInfo = desc.data.layout_desc.blocking;
471     auto offset = static_cast<size_t>(blkInfo.offset_padding);
472     SizeVector offsetsForDims;
473     SizeVector dims = getDims().ToSizeVector();
474     switch (getFormat()) {
475         case memory::format_undef:
476             THROW_IE_EXCEPTION << "Cannot cast to tensor desc. Format is undefined!";
477         case memory::any:
478             layout = Layout::ANY;
479             return TensorDesc(precision, dims, layout);
480         case memory::x:
481             layout = Layout::C;
482             order = {0};
483             blkDims = dims;
484             break;
485         case memory::oi:
486         case memory::nc:
487             layout = Layout::NC;
488             order = {0, 1};
489             blkDims = dims;
490             break;
491         case memory::tnc:
492             layout = Layout::CHW;
493             order = {0, 1, 2};
494             blkDims = dims;
495             break;
496         case memory::ntc:
497             layout = Layout::CHW;
498             order = {1, 0, 2};
499             blkDims = {static_cast<size_t>(dims[1]),
500                        static_cast<size_t>(dims[0]),
501                        static_cast<size_t>(dims[2])};
502             break;
503         case memory::oihw:
504         case memory::nchw:
505             layout = Layout::NCHW;
506             order = {0, 1, 2, 3};
507             blkDims = dims;
508             break;
509         case memory::ncdhw:
510             layout = Layout::NCDHW;
511             order = {0, 1, 2, 3, 4};
512             blkDims = dims;
513             break;
514         case memory::nhwc:
515             layout = Layout::NHWC;
516             order = {0, 2, 3, 1};
517             if (precision == Precision::BIN) {
518                 blkDims = {static_cast<size_t>(dims[0]),
519                            static_cast<size_t>(dims[2]),
520                            static_cast<size_t>(dims[3]),
521                            static_cast<size_t>(rnd_up(dims[1], 8))};
522             } else {
523                 blkDims = {static_cast<size_t>(dims[0]),
524                            static_cast<size_t>(dims[2]),
525                            static_cast<size_t>(dims[3]),
526                            static_cast<size_t>(dims[1])};
527             }
528             break;
529         case memory::ndhwc:
530             layout = Layout::NDHWC;
531             order = {0, 2, 3, 4, 1};
532             blkDims = {static_cast<size_t>(dims[0]),
533                        static_cast<size_t>(dims[2]),
534                        static_cast<size_t>(dims[3]),
535                        static_cast<size_t>(dims[4]),
536                        static_cast<size_t>(dims[1])};
537             break;
538         case memory::oIhw8i:
539         case memory::nChw8c:
540             order = {0, 1, 2, 3, 1};
541             blkDims = dims;
542             blkDims[1] = blkDims[1] / 8 + (blkDims[1] % 8 ? 1 : 0);
543             blkDims.push_back(8);
544             layout = Layout::BLOCKED;
545             break;
546         case memory::nCdhw8c:
547             order = {0, 1, 2, 3, 4, 1};
548             blkDims = dims;
549             blkDims[1] = blkDims[1] / 8 + (blkDims[1] % 8 ? 1 : 0);
550             blkDims.push_back(8);
551             layout = Layout::BLOCKED;
552             break;
553         case memory::nChw16c:
554             order = {0, 1, 2, 3, 1};
555             blkDims = dims;
556             blkDims[1] = blkDims[1] / 16 + (blkDims[1] % 16 ? 1 : 0);
557             blkDims.push_back(16);
558             layout = Layout::BLOCKED;
559             break;
560         case memory::nCdhw16c:
561             order = {0, 1, 2, 3, 4, 1};
562             blkDims = dims;
563             blkDims[1] = blkDims[1] / 16 + (blkDims[1] % 16 ? 1 : 0);
564             blkDims.push_back(16);
565             layout = Layout::BLOCKED;
566             break;
567         case memory::blocked:
568             order.clear();
569             blkDims = dims;
570             for (size_t i = 0; i < blkDims.size(); i++) {
571                 order.push_back(i);
572                 if ((i && blkInfo.strides[0][i - 1] < blkInfo.strides[0][i]) || blkInfo.block_dims[i] != 1) {
573                     THROW_IE_EXCEPTION << "Cannot cast to tensor desc."
574                                        << " Unsupported blocked format.";
575                 }
576             }
577             if (order.size() == 3 && order[0] == 0 && order[1] == 1 && order[2] == 2)
578                 layout = Layout::CHW;
579             else
580                 layout = Layout::BLOCKED;
581             break;
582         default:
583             THROW_IE_EXCEPTION << "Cannot cast to tensor desc. Format is unsupported!";
584     }
585
586     SizeVector strides(blkDims.size());
587
588     if (layout == Layout::NHWC || layout == Layout::NDHWC || layout == Layout::CHW) {
589         for (size_t i = 0; i < order.size(); i++) {
590             strides[i] = static_cast<size_t>(blkInfo.strides[0][order[i]]);
591         }
592     } else {
593         strides[blkDims.size() - 1] = 1;
594         for (size_t i = 2; i <= order.size(); i++) {
595             if (blkDims.size() - i < dims.size()) {
596                 strides[blkDims.size() - i] = static_cast<size_t>(blkInfo.strides[0][order[blkDims.size() - i]]);
597             } else {
598                 strides[blkDims.size() - i] = strides[blkDims.size() - i + 1] * blkDims[blkDims.size() - i + 1];
599             }
600         }
601     }
602
603     for (size_t i = 0; i < blkDims.size() && i < TENSOR_MAX_DIMS; i++) {
604         if (i < dims.size())
605             offsetsForDims.push_back(blkInfo.offset_padding_to_data[i]);
606         else
607             offsetsForDims.push_back(0);
608     }
609
610     TensorDesc tensorDesc(precision, dims, {blkDims, order, offset, offsetsForDims, strides});
611
612     tensorDesc.setLayout(layout);
613     return tensorDesc;
614 }
615
616 MKLDNNMemoryDesc::MKLDNNMemoryDesc(const TensorDesc& tDesc):
617         desc({}, mkldnn::memory::data_type::f32, mkldnn::memory::format::format_undef) {
618     mkldnn::memory::data_type data_type;
619     switch (tDesc.getPrecision()) {
620         case Precision::FP32:
621             data_type = mkldnn::memory::data_type::f32;
622             break;
623         case Precision::U8:
624             data_type = mkldnn::memory::data_type::u8;
625             break;
626         case Precision::I8:
627             data_type = mkldnn::memory::data_type::s8;
628             break;
629         case Precision::I16:
630             data_type = mkldnn::memory::data_type::s16;
631             break;
632         case Precision::I32:
633             data_type = mkldnn::memory::data_type::s32;
634             break;
635         case Precision::BIN:
636             data_type = mkldnn::memory::data_type::bin;
637             break;
638         default:
639             THROW_IE_EXCEPTION << "Cannot create MKLDNNMemoryDesc from TensorDesc. Unsupported precision!";
640     }
641
642     mkldnn::memory::format mkldnnFormat = memory::format::format_undef;
643     SizeVector blkdDims = tDesc.getBlockingDesc().getBlockDims();
644     SizeVector order = tDesc.getBlockingDesc().getOrder();
645     SizeVector offsetsToData = tDesc.getBlockingDesc().getOffsetPaddingToData();
646     SizeVector strides = tDesc.getBlockingDesc().getStrides();
647     auto realDims = MKLDNNDims(tDesc.getDims());
648     switch (tDesc.getLayout()) {
649         case ANY:
650             mkldnnFormat = memory::format::any;
651             break;
652         case NCHW:
653             mkldnnFormat = memory::format::nchw;
654             break;
655         case NCDHW:
656             mkldnnFormat = memory::format::ncdhw;
657             break;
658         case NHWC:
659             mkldnnFormat = memory::format::nhwc;
660             break;
661         case NDHWC:
662             mkldnnFormat = memory::format::ndhwc;
663             break;
664         case OIHW:
665             mkldnnFormat = memory::format::oihw;
666             break;
667         case SCALAR:
668         case C:
669             mkldnnFormat = memory::format::x;
670             break;
671         case CHW:
672             if (order == SizeVector{0, 1, 2})
673                 mkldnnFormat = memory::format::tnc;
674             else if (order == SizeVector{1, 0, 2})
675                 mkldnnFormat = memory::format::ntc;
676             else
677                 mkldnnFormat = memory::format::blocked;
678             break;
679         case HW:
680         case NC:
681             mkldnnFormat = memory::format::nc;
682             break;
683         case BLOCKED:
684             mkldnnFormat = memory::format::blocked;
685             if (realDims.ndims() == 1) {
686                 mkldnnFormat = memory::format::x;
687             } else if (realDims.ndims() == 2) {
688                 mkldnnFormat = memory::format::nc;
689             } else if (realDims.ndims() == 4) {
690                 if (order.size() == 5 && order[0] == 0 && order[1] == 1 && order[2] == 2 && order[3] == 3 && order[4] == 1) {
691                     if (blkdDims[4] == 8) {
692                         mkldnnFormat = memory::format::nChw8c;
693                     } else if (blkdDims[4] == 16) {
694                         mkldnnFormat = memory::format::nChw16c;
695                     }
696                 } else if (order.size() == 4) {
697                     if (order[0] == 0 && order[1] == 1 && order[2] == 2 && order[3] == 3) {
698                         mkldnnFormat = memory::format::nchw;
699                     } else if (order[0] == 0 && order[1] == 2 && order[2] == 3 && order[3] == 1) {
700                         mkldnnFormat = memory::format::nhwc;
701                     }
702                 }
703             } else if (realDims.ndims() == 5) {
704                 if (order.size() == 6 &&
705                         order[0] == 0 && order[1] == 1 && order[2] == 2 && order[3] == 3 && order[4] == 4 && order[5] == 1) {
706                     if (blkdDims[5] == 8) {
707                         mkldnnFormat = memory::format::nCdhw8c;
708                     } else if (blkdDims[5] == 16) {
709                         mkldnnFormat = memory::format::nCdhw16c;
710                     }
711                 } else if (order.size() == 5) {
712                     if (order[0] == 0 && order[1] == 1 && order[2] == 2 && order[3] == 3 && order[4] == 4) {
713                         mkldnnFormat = memory::format::ncdhw;
714                     } else if (order[0] == 0 && order[1] == 2 && order[2] == 3 && order[3] == 4 && order[4] == 1) {
715                         mkldnnFormat = memory::format::ndhwc;
716                     }
717                 }
718             }
719             break;
720         case CN:
721             mkldnnFormat = memory::format::blocked;
722             break;
723     }
724     if (mkldnnFormat == memory::format_undef)
725         THROW_IE_EXCEPTION << "Cannot detect the right memory format!";
726
727     bool notDefault = false;
728     size_t currentStride = 1;
729     for (size_t i = 0; i < order.size(); i++) {
730         if (offsetsToData[i] != 0) {
731             notDefault = true;
732             break;
733         }
734         if (strides[strides.size() - (1 +i)] != currentStride) {
735             notDefault = true;
736             break;
737         }
738         currentStride *= blkdDims[blkdDims.size() - (1 + i)];
739     }
740
741     bool blocked = false;
742     std::unordered_set<size_t> exist_order;
743     for (auto& ord : order) {
744         if (exist_order.find(ord) != exist_order.end()) {
745             blocked = true;
746             break;
747         }
748         exist_order.insert(ord);
749     }
750
751     if (notDefault && mkldnnFormat == memory::blocked && blocked)
752         THROW_IE_EXCEPTION << "Currently MKLDNNPlugin supports only packaged memory for unknown blocked format";
753
754     if (mkldnnFormat == memory::blocked) {
755         desc = MKLDNNMemoryDesc(realDims, data_type, memory::any);
756         desc.data.format = mkldnn_blocked;
757
758         auto& blk = desc.data.layout_desc.blocking;
759
760         blk.offset_padding = tDesc.getBlockingDesc().getOffsetPadding();
761
762         for (size_t i = 0; i < realDims.ndims(); i++) {
763             blk.block_dims[i] = 1;
764             blk.strides[1][i] = 1;
765             blk.padding_dims[i] = realDims[i];
766             blk.offset_padding_to_data[i] = offsetsToData[i];
767         }
768
769         int perm[TENSOR_MAX_DIMS] = {0};
770
771         for (size_t i = 0; i < realDims.ndims(); ++i) {
772             perm[i] = i;
773         }
774
775         blk.strides[0][perm[realDims.ndims() - 1]] = 1;
776
777         for (int d = 1; d < realDims.ndims(); ++d) {
778             const int prev_idx = perm[realDims.ndims() - d];
779             const int curr_idx = perm[realDims.ndims() - 1 - d];
780
781             blk.strides[0][curr_idx] = realDims[curr_idx] == 0 ? 1 : blk.strides[0][prev_idx] * (std::max)((ptrdiff_t)1, realDims[prev_idx]);
782         }
783     } else {
784         desc = MKLDNNMemoryDesc(realDims, data_type, mkldnnFormat);
785     }
786
787     desc.data.layout_desc.blocking.offset_padding = tDesc.getBlockingDesc().getOffsetPadding();
788     for (size_t i = 0; i < tDesc.getBlockingDesc().getOffsetPaddingToData().size() && i < TENSOR_MAX_DIMS; i++) {
789         desc.data.layout_desc.blocking.offset_padding_to_data[i] = static_cast<ptrdiff_t>(offsetsToData[i]);
790     }
791
792     if (notDefault) {
793         for (size_t i = 0; i < strides.size() && i < desc.data.ndims; i++) {
794             desc.data.layout_desc.blocking.strides[0][i] = static_cast<ptrdiff_t>(strides[order[i]]);
795         }
796     }
797 }
798
799 bool MKLDNNMemoryDesc::blocksExtended() const {
800     for (int i = 0; i < desc.data.ndims; i++) {
801         if (desc.data.dims[i] != desc.data.layout_desc.blocking.padding_dims[i])
802             return true;
803     }
804     return false;
805 }
806
807 }  // namespace MKLDNNPlugin