// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <mkldnn_plugin/mkldnn_graph.h>
#include <mkldnn_plugin/mkldnn_memory.h>
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include <mkldnn_plugin/mkldnn_graph_optimizer.h>
#include <mkldnn_plugin/nodes/mkldnn_input_node.h>

#include <gtest/gtest.h>

#include <cmath>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>
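
// Deterministic "garbage" value for index x. Buffers are pre-filled with it so
// checkDynBatch() can tell whether inference later overwrote a given element.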
#define GARB_VAL(x) ((x + 100.0f + sin(x)) / (x + 150.f))

class MKLDNNGraphTestClass: public MKLDNNPlugin::MKLDNNGraph {
public:
    enum class CheckDynBatchType {
        Both,
        Parent,
        Child
    };
    MKLDNNGraphTestClass(): MKLDNNPlugin::MKLDNNGraph() {}
    virtual ~MKLDNNGraphTestClass() = default;
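
    // Decomposes a bit-mask MKLDNNPlugin::impl_desc_type into a readable string
    // such as "jit_avx2_1x1" by naming every implementation flag that is set.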
    static std::string getStrPrimitiveDescriptorType(MKLDNNPlugin::impl_desc_type type) {
        std::string str_type;

        auto add_type = [&](std::string t) {
            if (!str_type.empty() && t.c_str()[0] != '_')
                str_type += "_";
            str_type += t;
        };
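
    // Appends the flag's name to str_type whenever all of its bits are set in 'type'.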
#define SEARCH_TYPE(_type)                                                                      \
    if ((type & MKLDNNPlugin::impl_desc_type::_type) == MKLDNNPlugin::impl_desc_type::_type)    \
        add_type(#_type)

        SEARCH_TYPE(undef);
        SEARCH_TYPE(reorder);
        SEARCH_TYPE(jit);
        SEARCH_TYPE(gemm);
        SEARCH_TYPE(ref);

        SEARCH_TYPE(avx512);
        SEARCH_TYPE(avx2);
        SEARCH_TYPE(sse42);
        SEARCH_TYPE(blas);
        SEARCH_TYPE(any);

        SEARCH_TYPE(winograd);
        SEARCH_TYPE(_dw);
        SEARCH_TYPE(_1x1);

#undef SEARCH_TYPE

        if (type == MKLDNNPlugin::impl_desc_type::unknown)
            str_type = "unknown";
        else if (str_type.empty())
            str_type = "undef";
        return str_type;
    }
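
    // Copies an external input blob into the input node's edge memory, sizing the
    // copy for the requested dynamic batch, and applies mean-image subtraction
    // when one is configured for this input.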
    void PushInputData(const std::string& name, const InferenceEngine::Blob::Ptr &in, int batch) {
        if (!IsReady()) THROW_IE_EXCEPTION << "Wrong state. Topology not ready.";

        auto input = inputNodes.find(name);
        if (input != inputNodes.end()) {
            MKLDNNPlugin::MKLDNNDims outDims = input->second->getChildEdgeAt(0)->getDims();
            if (batch < 1)
                batch = outDims[0];

            const void *ext_data_ptr = in->cbuffer();
            void *inter_data_ptr = input->second->getChildEdgeAt(0)->getMemory().GetData();

            if (ext_data_ptr != inter_data_ptr)
                input->second->getChildEdgeAt(0)->getMemory().SetData(MKLDNNPlugin::MKLDNNExtensionUtils::IEPrecisionToDataType(in->getTensorDesc().getPrecision()),
                                                                      MKLDNNPlugin::MKLDNNMemory::GetPlainFormat(outDims), ext_data_ptr, in->byteSize() / outDims[0] * batch, false);

            // todo: make sure 'name' exists in this map...
            if (_meanImages.find(name) != _meanImages.end()) {
                if (in->getTensorDesc().getPrecision() == InferenceEngine::Precision::FP32) {
                    _meanImages[name].Subtract(outDims, reinterpret_cast<float *>(inter_data_ptr), in->getTensorDesc().getLayout());
                } else {
                    THROW_IE_EXCEPTION << "Mean image of type " << in->getTensorDesc().getPrecision().name() << " is unsupported";
                }
            }
        } else {
            THROW_IE_EXCEPTION << "Input blob for infer '" << name << "' doesn't correspond to input in network";
        }
    }
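
    // Checks that every input blob has a supported precision and allocated memory,
    // pushes the inputs into the graph, runs inference with the requested batch,
    // and collects the outputs. Failures are reported through gtest assertions.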
    void Infer(const InferenceEngine::BlobMap& inputs, InferenceEngine::BlobMap& result, int batch = -1) {
        try {
            for (const auto &input : inputs) {
                switch (input.second->precision()) {
                    case InferenceEngine::Precision::FP32: {
                        auto *in_f = dynamic_cast<InferenceEngine::TBlob<float> *>(input.second.get());
                        if (in_f == nullptr) {
                            FAIL() << "Input data precision not supported. Expected float.";
                        }
                        if (in_f->readOnly() == nullptr) {
                            THROW_IE_EXCEPTION << "Input data was not allocated.";
                        }
                        break;
                    }
                    case InferenceEngine::Precision::I32: {
                        auto *in_i32 = dynamic_cast<InferenceEngine::TBlob<int32_t> *>(input.second.get());
                        if (in_i32 == nullptr) {
                            FAIL() << "Input data precision not supported. Expected int32_t.";
                        }
                        if (in_i32->readOnly() == nullptr) {
                            THROW_IE_EXCEPTION << "Input data was not allocated.";
                        }
                        break;
                    }
                    case InferenceEngine::Precision::U16: {
                        auto *in_u16 = dynamic_cast<InferenceEngine::TBlob<uint16_t> *>(input.second.get());
                        if (in_u16 == nullptr) {
                            FAIL() << "Input data precision not supported. Expected uint16_t.";
                        }
                        if (in_u16->readOnly() == nullptr) {
                            THROW_IE_EXCEPTION << "Input data was not allocated.";
                        }
                        break;
                    }
                    case InferenceEngine::Precision::I16: {
                        auto *in_i16 = dynamic_cast<InferenceEngine::TBlob<int16_t> *>(input.second.get());
                        if (in_i16 == nullptr) {
                            FAIL() << "Input data precision not supported. Expected int16_t.";
                        }
                        if (in_i16->readOnly() == nullptr) {
                            THROW_IE_EXCEPTION << "Input data was not allocated.";
                        }
                        break;
                    }
                    case InferenceEngine::Precision::U8: {
                        auto *in_u8 = dynamic_cast<InferenceEngine::TBlob<uint8_t> *>(input.second.get());
                        if (in_u8 == nullptr) {
                            FAIL() << "Input data precision not supported. Expected uint8_t.";
                        }
                        if (in_u8->readOnly() == nullptr) {
                            THROW_IE_EXCEPTION << "Input data was not allocated.";
                        }
                        break;
                    }
                    case InferenceEngine::Precision::I8: {
                        auto *in_i8 = dynamic_cast<InferenceEngine::TBlob<int8_t> *>(input.second.get());
                        if (in_i8 == nullptr) {
                            FAIL() << "Input data precision not supported. Expected int8_t.";
                        }
                        if (in_i8->readOnly() == nullptr) {
                            THROW_IE_EXCEPTION << "Input data was not allocated.";
                        }
                        break;
                    }
                    default:
                        THROW_IE_EXCEPTION << "Unsupported input precision " << input.second->precision();
                }

                PushInputData(input.first, input.second, batch);
            }
            MKLDNNPlugin::MKLDNNGraph::Infer(batch);
        } catch (const std::exception &e) {
            FAIL() << e.what();
        }

        PullOutputData(result);
    }
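
    // Expose the protected node list so tests can inspect individual primitives.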
    std::vector<MKLDNNPlugin::MKLDNNNodePtr>& getNodes() {
        return graphNodes;
    }
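
    // Expose the protected MKLDNNGraph::CreateGraph so tests can build a graph
    // straight from a CNNNetwork, optionally with a custom extension manager.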
    void CreateGraph(InferenceEngine::ICNNNetwork &network, const MKLDNNPlugin::MKLDNNExtensionManager::Ptr& extMgr) {
        MKLDNNGraph::CreateGraph(network, extMgr);
    }

    void CreateGraph(InferenceEngine::ICNNNetwork &network) {
        MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr;
        CreateGraph(network, extMgr);
    }
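
    // Verifies dynamic-batch handling for every node selected by 'comp': the
    // node's edge buffers are pre-filled with GARB_VAL garbage, inference runs
    // with a reduced batch, then the test asserts that the first 'batch' items
    // were overwritten while the tail (up to the full batch MB) kept its garbage.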
    void checkDynBatch(InferenceEngine::BlobMap& srcs, InferenceEngine::BlobMap& outputBlobs, int batch, size_t MB,
                       const std::function<bool (const MKLDNNPlugin::MKLDNNNodePtr&)>& comp, CheckDynBatchType type = CheckDynBatchType::Both) {
        for (auto &node : getNodes()) {
            if (comp(node)) {
                auto inputBlob = node->getParentEdgeAt(0)->getBlob();
                auto *data = inputBlob->buffer().as<float *>();
                size_t dataSize = inputBlob->getTensorDesc().getBlockingDesc().getStrides()[0] * MB;
                for (size_t j = 0; j < dataSize; j++) {
                    data[j] = GARB_VAL(j);
                }

                auto outputBlob = node->getChildEdgeAt(0)->getBlob();
                data = outputBlob->buffer().as<float *>();
                dataSize = outputBlob->getTensorDesc().getBlockingDesc().getStrides()[0] * MB;
                for (size_t j = 0; j < dataSize; j++) {
                    data[j] = GARB_VAL(j);
                }
            }
        }

        Infer(srcs, outputBlobs, batch);

        for (auto &node : getNodes()) {
            if (comp(node)) {
                auto inputBlob = node->getParentEdgeAt(0)->getBlob();
                auto *data = inputBlob->buffer().as<float *>();
                auto inputNoBatchSize = inputBlob->getTensorDesc().getBlockingDesc().getStrides()[0];
                for (size_t i = 0; i < static_cast<size_t>(batch); i++) {
                    for (size_t j = 0; j < inputNoBatchSize; j++) {
                        ASSERT_NE(data[i*inputNoBatchSize + j], GARB_VAL(i*inputNoBatchSize + j));
                    }
                }

                if (type == CheckDynBatchType::Both || type == CheckDynBatchType::Parent) {
                    for (size_t i = static_cast<size_t>(batch); i < MB; i++) {
                        for (size_t j = 0; j < inputNoBatchSize; j++) {
                            ASSERT_NEAR(data[i * inputNoBatchSize + j],
                                        GARB_VAL(i * inputNoBatchSize + j), 0.001f);
                        }
                    }
                }

                auto outputBlob = node->getChildEdgeAt(0)->getBlob();
                data = outputBlob->buffer().as<float *>();
                auto outputNoBatchSize = outputBlob->getTensorDesc().getBlockingDesc().getStrides()[0];
                for (size_t i = 0; i < static_cast<size_t>(batch); i++) {
                    for (size_t j = 0; j < outputNoBatchSize; j++) {
                        ASSERT_NE(data[i*outputNoBatchSize + j], GARB_VAL(i*outputNoBatchSize + j));
                    }
                }
                if (type == CheckDynBatchType::Both || type == CheckDynBatchType::Child) {
                    for (size_t i = static_cast<size_t>(batch); i < MB; i++) {
                        for (size_t j = 0; j < outputNoBatchSize; j++) {
                            ASSERT_NEAR(data[i * outputNoBatchSize + j],
                                        GARB_VAL(i * outputNoBatchSize + j), 0.001f);
                        }
                    }
                }
            }
        }
    }
};
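
// Typical usage in a unit test (illustrative sketch only; building 'network'
// and the input/output BlobMaps is elided, and the Convolution predicate is
// just an example of a node filter):
//
//   MKLDNNGraphTestClass graph;
//   graph.CreateGraph(network);
//   graph.Infer(srcs, outs);
//   graph.checkDynBatch(srcs, outs, dynBatch, maxBatch,
//                       [](const MKLDNNPlugin::MKLDNNNodePtr &node) {
//                           return node->getType() == MKLDNNPlugin::Convolution;
//                       });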