inference-engine/src/mkldnn_plugin/mkldnn_extension_utils.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mkldnn_extension_utils.h"
#include <limits>
#include <vector>

using namespace mkldnn;
using namespace MKLDNNPlugin;
// Returns the size in bytes of one element of the given MKLDNN data type,
// or 0 for data_undef.
uint8_t MKLDNNExtensionUtils::sizeOfDataType(mkldnn::memory::data_type dataType) {
    switch (dataType) {
    case mkldnn::memory::data_type::f32:
        return 4;
    case mkldnn::memory::data_type::s32:
        return 4;
    case mkldnn::memory::data_type::s16:
        return 2;
    case mkldnn::memory::data_type::s8:
        return 1;
    case mkldnn::memory::data_type::u8:
        return 1;
    case mkldnn::memory::data_type::bin:
        return 1;
    case mkldnn::memory::data_type::data_undef:
        return 0;
    default:
        THROW_IE_EXCEPTION << "Unsupported data type.";
    }
}
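
// Usage sketch (illustrative only, not part of the original file): the element
// size is typically multiplied by an element count to size a dense allocation.
//
//     size_t elements = 1 * 3 * 224 * 224;  // hypothetical NCHW shape
//     size_t bytes = MKLDNNExtensionUtils::sizeOfDataType(memory::f32) * elements;
//     // bytes == 4 * 150528 == 602112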

// Maps an Inference Engine precision to the corresponding MKLDNN data type.
memory::data_type MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision prec) {
    switch (prec) {
        case InferenceEngine::Precision::FP32:
            return memory::f32;
        case InferenceEngine::Precision::I32:
            return memory::s32;
        case InferenceEngine::Precision::I16:
            return memory::s16;
        case InferenceEngine::Precision::I8:
            return memory::s8;
        case InferenceEngine::Precision::U8:
            return memory::u8;
        case InferenceEngine::Precision::BIN:
            return memory::bin;
        default: {
            THROW_IE_EXCEPTION << "The plugin does not support " << prec.name();
        }
    }
}
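
// Sketch (illustrative): a typical call site converts an IE blob precision
// before sizing or creating MKLDNN memory, e.g.
//
//     memory::data_type dt = MKLDNNExtensionUtils::IEPrecisionToDataType(
//         InferenceEngine::Precision::U8);   // -> memory::u8
//     // MKLDNNExtensionUtils::sizeOfDataType(dt) == 1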

// Maps an MKLDNN data type back to the corresponding Inference Engine precision.
InferenceEngine::Precision MKLDNNExtensionUtils::DataTypeToIEPrecision(memory::data_type dataType) {
    switch (dataType) {
        case memory::f32:
            return InferenceEngine::Precision::FP32;
        case memory::s32:
            return InferenceEngine::Precision::I32;
        case memory::s16:
            return InferenceEngine::Precision::I16;
        case memory::s8:
            return InferenceEngine::Precision::I8;
        case memory::u8:
            return InferenceEngine::Precision::U8;
        case memory::bin:
            return InferenceEngine::Precision::BIN;
        default: {
            THROW_IE_EXCEPTION << "Unsupported data type.";
        }
    }
}
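
// Round-trip sketch (illustrative): for every precision handled above,
// DataTypeToIEPrecision is the inverse of IEPrecisionToDataType, e.g.
//
//     auto p = MKLDNNExtensionUtils::DataTypeToIEPrecision(memory::s32);
//     // p == InferenceEngine::Precision::I32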

// Builds a copy of the given TensorDesc in which the offset padding and the
// strides are marked "uninitialized" (std::numeric_limits<size_t>::max() is the
// sentinel) while block dims and order are preserved; per-dimension data
// offsets are zeroed.
InferenceEngine::TensorDesc MKLDNNExtensionUtils::getUninitTensorDesc(const InferenceEngine::TensorDesc &desc) {
    std::vector<size_t> notInitArr;
    std::vector<size_t> zeroArr;
    for (size_t i = 0; i < desc.getBlockingDesc().getBlockDims().size(); i++) {
        notInitArr.push_back(std::numeric_limits<size_t>::max());
        zeroArr.push_back(0);
    }
    // MKLDNN doesn't support offset_padding_to_data[i] != 0
    // (it asserts src_d_blk.offset_padding_to_data[d] == 0), so it is zeroed here.
    return desc.getLayout() == InferenceEngine::Layout::ANY ? desc :
           InferenceEngine::TensorDesc(desc.getPrecision(), desc.getDims(),
                                       {desc.getBlockingDesc().getBlockDims(), desc.getBlockingDesc().getOrder(),
                                        std::numeric_limits<size_t>::max(), zeroArr, notInitArr});
}
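
// Sketch (illustrative): an "uninitialized" descriptor keeps the real dims,
// order, and precision but leaves strides and offset as wildcards, so a later
// layout decision can fill them in:
//
//     InferenceEngine::TensorDesc src(InferenceEngine::Precision::FP32,
//                                     {1, 3, 224, 224},
//                                     InferenceEngine::Layout::NCHW);
//     auto uninit = MKLDNNExtensionUtils::getUninitTensorDesc(src);
//     // every entry of uninit.getBlockingDesc().getStrides() is SIZE_MAX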

// Compares two tensor descriptors, treating std::numeric_limits<size_t>::max()
// in either descriptor as an "uninitialized" wildcard that matches any value
// (see getUninitTensorDesc above).
bool MKLDNNExtensionUtils::initTensorsAreEqual(InferenceEngine::TensorDesc desc1, InferenceEngine::TensorDesc desc2) {
    if (desc1.getDims() != desc2.getDims() || desc1.getPrecision() != desc2.getPrecision())
        return false;
    if (desc1.getLayout() == InferenceEngine::Layout::ANY || desc2.getLayout() == InferenceEngine::Layout::ANY)
        return true;
    bool batch1 = desc1.getDims()[0] == 1;
    const auto& in1Block = desc1.getBlockingDesc();
    const auto& in2Block = desc2.getBlockingDesc();
    size_t uninitNum = std::numeric_limits<size_t>::max();
    if (in1Block.getBlockDims().size() != in2Block.getBlockDims().size())
        return false;
    for (size_t i = 0; i < in1Block.getBlockDims().size(); i++) {
        if (in1Block.getBlockDims()[i] != in2Block.getBlockDims()[i] &&
                in1Block.getBlockDims()[i] != uninitNum && in2Block.getBlockDims()[i] != uninitNum)
            return false;
        if (in1Block.getOffsetPaddingToData()[i] != in2Block.getOffsetPaddingToData()[i] &&
                in1Block.getOffsetPaddingToData()[i] != uninitNum && in2Block.getOffsetPaddingToData()[i] != uninitNum)
            return false;
        // When the batch size is 1 the stride of the batch dimension is irrelevant,
        // so the i == 0 check is skipped (batch1 promotes to 1 in `i >= batch1`).
        if (i >= batch1 && in1Block.getStrides()[i] != in2Block.getStrides()[i] &&
                in1Block.getStrides()[i] != uninitNum && in2Block.getStrides()[i] != uninitNum)
            return false;
        if (in1Block.getOrder()[i] != in2Block.getOrder()[i] &&
                in1Block.getOrder()[i] != uninitNum && in2Block.getOrder()[i] != uninitNum)
            return false;
    }
    return !(in1Block.getOffsetPadding() != in2Block.getOffsetPadding() &&
        in1Block.getOffsetPadding() != uninitNum && in2Block.getOffsetPadding() != uninitNum);
}
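
// Usage sketch (illustrative): a fully specified descriptor compares equal to
// its "uninitialized" counterpart, because the wildcard entries match anything.
//
//     InferenceEngine::TensorDesc real(InferenceEngine::Precision::FP32,
//                                      {1, 3, 224, 224},
//                                      InferenceEngine::Layout::NCHW);
//     auto uninit = MKLDNNExtensionUtils::getUninitTensorDesc(real);
//     bool same = MKLDNNExtensionUtils::initTensorsAreEqual(real, uninit);  // true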