Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_gemm_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace mkldnn;
19
20 struct gemm_test_params {
21     struct {
22         size_t MB1_A;
23         size_t MB2_A;
24         size_t MB1_B;
25         size_t MB2_B;
26         size_t MB1_C;
27         size_t MB2_C;
28         size_t MB1_D;
29         size_t MB2_D;
30     } batches;
31
32     size_t M;
33     size_t N;
34     size_t K;
35
36     float alpha;
37     float beta;
38
39     bool transposeA;
40     bool transposeB;
41
42     size_t num_prim_desc;
43
44     MKLDNNPlugin::impl_desc_type selectedType;
45
46     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
47 };
48
49 template<typename data_t>
50 void ref_gemm(const std::vector<InferenceEngine::TBlob<data_t>> &src, InferenceEngine::TBlob<data_t> &dst,
51               gemm_test_params prm) {
52     const data_t *src0_data = src[0].readOnly();
53     const data_t *src1_data = src[1].readOnly();
54     const data_t *src2_data = src.size() == 3 ? src[2].readOnly() : dst.readOnly();
55     data_t *dst_data = dst.data();
56
57     size_t MB1 = prm.batches.MB1_D;
58     size_t MB2 = prm.batches.MB2_D;
59     size_t M  = prm.M;
60     size_t N  = prm.N;
61     size_t K  = prm.K;
62
63     for (int mb1 = 0; mb1 < MB1; mb1++) {
64         const data_t *a_data = src0_data;
65         const data_t *b_data = src1_data;
66         const data_t *c_data = src2_data;
67         data_t *d_data = dst_data;
68
69         for (int mb2 = 0; mb2 < MB2; mb2++) {
70             for (int i = 0; i < M; i++) {
71                 for (int j = 0; j < N; j++) {
72                     d_data[i * N + j] = src.size() == 3 ? prm.beta * c_data[i * N + j] : 0;
73
74                     for (int k = 0; k < K; k++) {
75                         size_t src0_off = prm.transposeA ? k * M + i : i * K + k;
76                         size_t src1_off = prm.transposeB ? j * K + k : k * N + j;
77                         d_data[i * N + j] += prm.alpha * a_data[src0_off] * b_data[src1_off];
78                     }
79                 }
80             }
81             a_data += prm.batches.MB2_A == MB2 ? M*K : 0;
82             b_data += prm.batches.MB2_B == MB2 ? K*N : 0;
83             c_data += prm.batches.MB2_C == MB2 ? M*N : 0;
84             d_data += M*N;
85         }
86
87         src0_data += prm.batches.MB1_A == MB1 ? prm.batches.MB2_A*M*K : 0;
88         src1_data += prm.batches.MB1_B == MB1 ? prm.batches.MB2_B*K*N : 0;
89         src2_data += prm.batches.MB1_C == MB1 ? prm.batches.MB2_C*M*N : 0;
90         dst_data += prm.batches.MB2_D*M*N;
91     }
92 }
93
94 class MKLDNNGraphGemmTests: public TestsCommon,
95                                      public WithParamInterface<gemm_test_params> {
96     std::string model_t = R"V0G0N(
97 <net name="gemmOnly" version="2" precision="FP32" batch="1">
98     <layers>
99         <layer name="in1" type="Input" precision="FP32" id="1">
100             <output>
101                 <port id="1">
102                     <dim>_MB1_A_</dim>
103                     <dim>_MB2_A_</dim>
104                     <dim>_M_</dim>
105                     <dim>_K_</dim>
106                 </port>
107             </output>
108         </layer>
109         <layer name="in2" type="Input" precision="FP32" id="2">
110             <output>
111                 <port id="1">
112                     <dim>_MB1_B_</dim>
113                     <dim>_MB2_B_</dim>
114                     <dim>_K_</dim>
115                     <dim>_N_</dim>
116                 </port>
117             </output>
118         </layer>
119         <layer name="in3" type="Input" precision="FP32" id="3">
120             <output>
121                 <port id="1">
122                     <dim>_MB1_C_</dim>
123                     <dim>_MB2_C_</dim>
124                     <dim>_M_</dim>
125                     <dim>_N_</dim>
126                 </port>
127             </output>
128         </layer>
129         <layer name="gemm" id="4" type="GEMM" precision="FP32">
130             <data alpha="_A_" beta="_B_" transpose_a="_TA_" transpose_b="_TB_"/>
131             <input>
132                 <port id="1">
133                     <dim>_MB1_A_</dim>
134                     <dim>_MB2_A_</dim>
135                     <dim>_M_</dim>
136                     <dim>_K_</dim>
137                 </port>
138                 <port id="2">
139                     <dim>_MB1_B_</dim>
140                     <dim>_MB2_B_</dim>
141                     <dim>_K_</dim>
142                     <dim>_N_</dim>
143                 </port>
144                 <port id="3">
145                     <dim>_MB1_C_</dim>
146                     <dim>_MB2_C_</dim>
147                     <dim>_M_</dim>
148                     <dim>_N_</dim>
149                 </port>
150             </input>
151             <output>
152                 <port id="4">
153                     <dim>_MB1_D_</dim>
154                     <dim>_MB2_D_</dim>
155                     <dim>_M_</dim>
156                     <dim>_N_</dim>
157                 </port>
158             </output>
159         </layer>
160     </layers>
161     <edges>
162         <edge from-layer="1" from-port="1" to-layer="4" to-port="1"/>
163         <edge from-layer="2" from-port="1" to-layer="4" to-port="2"/>
164         <edge from-layer="3" from-port="1" to-layer="4" to-port="3"/>
165     </edges>
166 </net>
167 )V0G0N";
168
169 protected:
170     std::string getModel(gemm_test_params p) {
171         std::string model = model_t;
172         std::string op;
173
174         REPLACE_WITH_NUM(model, "_MB1_A_", p.batches.MB1_A);
175         REPLACE_WITH_NUM(model, "_MB2_A_", p.batches.MB2_A);
176         REPLACE_WITH_NUM(model, "_MB1_B_", p.batches.MB1_B);
177         REPLACE_WITH_NUM(model, "_MB2_B_", p.batches.MB2_B);
178         REPLACE_WITH_NUM(model, "_MB1_C_", p.batches.MB1_C);
179         REPLACE_WITH_NUM(model, "_MB2_C_", p.batches.MB2_C);
180         REPLACE_WITH_NUM(model, "_MB1_D_", p.batches.MB1_D);
181         REPLACE_WITH_NUM(model, "_MB2_D_", p.batches.MB2_D);
182
183         REPLACE_WITH_NUM(model, "_M_", p.M);
184         REPLACE_WITH_NUM(model, "_N_", p.N);
185         REPLACE_WITH_NUM(model, "_K_", p.K);
186
187         REPLACE_WITH_NUM(model, "_A_", p.alpha);
188         REPLACE_WITH_NUM(model, "_B_", p.beta);
189         REPLACE_WITH_NUM(model, "_TA_", p.transposeA);
190         REPLACE_WITH_NUM(model, "_TB_", p.transposeB);
191
192         return model;
193     }
194
195     virtual void TearDown() {
196     }
197
198     virtual void SetUp() {
199         try {
200             TestsCommon::SetUp();
201             gemm_test_params p = ::testing::WithParamInterface<gemm_test_params>::GetParam();
202             std::string model = getModel(p);
203
204             InferenceEngine::CNNNetReader net_reader;
205             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
206
207             MKLDNNGraphTestClass graph;
208             graph.CreateGraph(net_reader.getNetwork());
209
210             auto& nodes = graph.getNodes();
211             for (int i = 0; i < nodes.size(); i++) {
212                 if (nodes[i]->getType() == MKLDNNPlugin::Gemm) {
213                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
214                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
215                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
216                     }
217                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
218                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
219                 }
220             }
221
222             InferenceEngine::SizeVector dims_src1 = {p.batches.MB1_A, p.batches.MB2_A, p.M, p.K};
223             InferenceEngine::SizeVector dims_src2 = {p.batches.MB1_B, p.batches.MB2_B, p.K, p.N};
224             InferenceEngine::SizeVector dims_src3 = {p.batches.MB1_C, p.batches.MB2_C, p.M, p.N};
225             InferenceEngine::SizeVector dims_dst  = {p.batches.MB1_D, p.batches.MB2_D, p.M, p.N};
226
227             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
228             src1->allocate();
229             InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
230             if (srcPtr1 == nullptr)
231                 FAIL() << "Cannot cast blob to TBlob<float>.";
232             fill_data(src1->buffer(), src1->size());
233
234             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
235             src2->allocate();
236             InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
237             if (srcPtr2 == nullptr)
238                 FAIL() << "Cannot cast blob to TBlob<float>.";
239             fill_data(src2->buffer(), src2->size());
240
241             InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src3);
242             src3->allocate();
243             InferenceEngine::TBlob<float>* srcPtr3 = dynamic_cast<InferenceEngine::TBlob<float>*>(src3.get());
244             if (srcPtr3 == nullptr)
245                 FAIL() << "Cannot cast blob to TBlob<float>.";
246             fill_data(src3->buffer(), src3->size());
247
248             InferenceEngine::BlobMap srcs;
249             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
250             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
251             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));
252
253             InferenceEngine::OutputsDataMap out;
254             out = net_reader.getNetwork().getOutputsInfo();
255             InferenceEngine::BlobMap outputBlobs;
256
257             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
258
259             InferenceEngine::TBlob<float>::Ptr output;
260             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
261             output->allocate();
262             outputBlobs[item.first] = output;
263
264             graph.Infer(srcs, outputBlobs);
265
266             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
267             dst_ref.allocate();
268
269             std::vector<InferenceEngine::TBlob<float>> src_vec = {*srcPtr1, *srcPtr2, *srcPtr3};
270
271             ref_gemm(src_vec, dst_ref, p);
272
273             compare(*output, dst_ref);
274         } catch (const InferenceEngine::details::InferenceEngineException &e) {
275             FAIL() << e.what();
276         }
277     }
278 };
279
280 TEST_P(MKLDNNGraphGemmTests, TestsGemm) {}
281
282 INSTANTIATE_TEST_CASE_P(
283         TestsGemm, MKLDNNGraphGemmTests,
284         ::testing::Values(
285                 gemm_test_params{{2, 1, 2, 1, 2, 1, 2, 1}, 3, 3, 2, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
286                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
287                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
288                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
289                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
290                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
291                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
292                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
293                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
294                         }
295                 } },
296                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 8, 5, 4, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
297                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
298                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
299                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
300                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
301                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
302                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
303                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
304                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
305                         }
306                 } },
307                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 16, 10, 12, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
308                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
309                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
310                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
311                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
312                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
313                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
314                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
315                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
316                         }
317                 } },
318                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 11, 10, 20, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
319                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
320                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
321                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
322                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
323                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
324                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
325                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
326                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
327                         }
328                 } },
329                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 13, 2, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
330                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
331                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
332                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
333                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
334                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
335                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
336                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
337                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
338                         }
339                 } },
340                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 15, 10, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
341                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
342                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
343                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
344                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
345                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
346                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
347                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
348                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
349                         }
350                 } },
351                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 6, 7, 2, 0, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
352                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 6, 7, 0, 2, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
353                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 3, 7, 4, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
354                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 3, 4, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
355                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
356                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 3, 7, 4, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
357                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 3, 4, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
358                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
359                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 3, 7, 4, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
360                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 3, 4, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
361                 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
362                 gemm_test_params{{1, 3, 2, 3, 2, 3, 2, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
363                 gemm_test_params{{1, 3, 2, 3, 1, 3, 2, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
364                 gemm_test_params{{2, 3, 1, 3, 1, 3, 2, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
365                 gemm_test_params{{5, 3, 5, 1, 5, 3, 5, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
366                 gemm_test_params{{5, 3, 5, 1, 5, 1, 5, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
367                 gemm_test_params{{5, 1, 5, 1, 5, 3, 5, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
368                 gemm_test_params{{1, 1, 5, 3, 5, 3, 5, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
369                 gemm_test_params{{1, 1, 1, 1, 5, 3, 5, 3}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
370                 gemm_test_params{{5, 4, 1, 1, 1, 1, 5, 4}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any}
371         ));
372
373 class MKLDNNGraphDynBatchGemmTests: public MKLDNNGraphGemmTests {
374 protected:
375     virtual void SetUp() {
376         try {
377             TestsCommon::SetUp();
378             gemm_test_params p = ::testing::WithParamInterface<gemm_test_params>::GetParam();
379             std::string model = getModel(p);
380             size_t MB = p.batches.MB1_D;
381             if (MB < 2)
382                 MB = 2;
383
384             InferenceEngine::CNNNetReader net_reader;
385             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
386             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
387             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
388             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
389             InferenceEngine::ResponseDesc resp;
390             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
391             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
392
393             MKLDNNGraphTestClass graph;
394             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
395             graph.CreateGraph(net_reader.getNetwork());
396
397             InferenceEngine::SizeVector dims_src1 = {MB, p.batches.MB2_A, p.M, p.K};
398             InferenceEngine::SizeVector dims_src2 = {MB, p.batches.MB2_B, p.K, p.N};
399             InferenceEngine::SizeVector dims_src3 = {MB, p.batches.MB2_C, p.M, p.N};
400             InferenceEngine::SizeVector dims_dst  = {MB, p.batches.MB2_D, p.M, p.N};
401
402             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
403             src1->allocate();
404             InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
405             if (srcPtr1 == nullptr)
406                 FAIL() << "Cannot cast blob to TBlob<float>.";
407             fill_data(src1->buffer(), src1->size());
408
409             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
410             src2->allocate();
411             InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
412             if (srcPtr2 == nullptr)
413                 FAIL() << "Cannot cast blob to TBlob<float>.";
414             fill_data(src2->buffer(), src2->size());
415
416             InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src3);
417             src3->allocate();
418             InferenceEngine::TBlob<float>* srcPtr3 = dynamic_cast<InferenceEngine::TBlob<float>*>(src3.get());
419             if (srcPtr3 == nullptr)
420                 FAIL() << "Cannot cast blob to TBlob<float>.";
421             fill_data(src3->buffer(), src3->size());
422
423             InferenceEngine::BlobMap srcs;
424             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
425             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
426             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));
427
428             InferenceEngine::OutputsDataMap out;
429             out = net_reader.getNetwork().getOutputsInfo();
430             InferenceEngine::BlobMap outputBlobs;
431
432             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
433
434             InferenceEngine::TBlob<float>::Ptr output;
435             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
436             output->allocate();
437             outputBlobs[item.first] = output;
438
439             auto check = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
440                 return node->getType() == MKLDNNPlugin::Gemm;
441             };
442
443             graph.checkDynBatch(srcs, outputBlobs, MB, MB, check);
444             graph.checkDynBatch(srcs, outputBlobs, 1, MB, check);
445         } catch (const InferenceEngine::details::InferenceEngineException &e) {
446             FAIL() << e.what();
447         }
448     }
449 };
450
451 TEST_P(MKLDNNGraphDynBatchGemmTests, TestsDynBatchGemm) {}
452
453 INSTANTIATE_TEST_CASE_P(
454         TestsDynBatchGemm, MKLDNNGraphDynBatchGemmTests,
455         ::testing::Values(
456                 gemm_test_params{{1, 3, 1, 3, 1, 3, 1, 3}, 3, 3, 3, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
457                 gemm_test_params{{1, 3, 1, 1, 1, 3, 1, 3}, 16, 15, 12, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any}
458 ));
459
460 class MKLDNNGraphSingleBatchDimGemmTests: public TestsCommon,
461                                      public WithParamInterface<gemm_test_params> {
462     std::string model_t = R"V0G0N(
463 <net name="gemmOnly" version="2" precision="FP32" batch="1">
464     <layers>
465         <layer name="in1" type="Input" precision="FP32" id="1">
466             <output>
467                 <port id="1">
468                     <dim>_MB_A_</dim>
469                     <dim>_M_</dim>
470                     <dim>_K_</dim>
471                 </port>
472             </output>
473         </layer>
474         <layer name="in2" type="Input" precision="FP32" id="2">
475             <output>
476                 <port id="1">
477                     <dim>_MB_B_</dim>
478                     <dim>_K_</dim>
479                     <dim>_N_</dim>
480                 </port>
481             </output>
482         </layer>
483         <layer name="gemm" id="3" type="GEMM" precision="FP32">
484             <data alpha="_A_" beta="_B_" transpose_a="_TA_" transpose_b="_TB_"/>
485             <input>
486                 <port id="1">
487                     <dim>_MB_A_</dim>
488                     <dim>_M_</dim>
489                     <dim>_K_</dim>
490                 </port>
491                 <port id="2">
492                     <dim>_MB_B_</dim>
493                     <dim>_K_</dim>
494                     <dim>_N_</dim>
495                 </port>
496             </input>
497             <output>
498                 <port id="3">
499                     <dim>_MB_D_</dim>
500                     <dim>_M_</dim>
501                     <dim>_N_</dim>
502                 </port>
503             </output>
504         </layer>
505     </layers>
506     <edges>
507         <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
508         <edge from-layer="2" from-port="1" to-layer="3" to-port="2"/>
509     </edges>
510 </net>
511 )V0G0N";
512
513 protected:
514     std::string getModel(gemm_test_params p) {
515         std::string model = model_t;
516         std::string op;
517
518         REPLACE_WITH_NUM(model, "_MB_A_", p.batches.MB2_A);
519         REPLACE_WITH_NUM(model, "_MB_B_", p.batches.MB2_B);
520         REPLACE_WITH_NUM(model, "_MB_D_", p.batches.MB2_D);
521
522         REPLACE_WITH_NUM(model, "_M_", p.M);
523         REPLACE_WITH_NUM(model, "_N_", p.N);
524         REPLACE_WITH_NUM(model, "_K_", p.K);
525
526         REPLACE_WITH_NUM(model, "_A_", p.alpha);
527         REPLACE_WITH_NUM(model, "_B_", p.beta);
528         REPLACE_WITH_NUM(model, "_TA_", p.transposeA);
529         REPLACE_WITH_NUM(model, "_TB_", p.transposeB);
530
531         return model;
532     }
533
534     virtual void TearDown() {
535     }
536
537     virtual void SetUp() {
538         try {
539             TestsCommon::SetUp();
540             gemm_test_params p = ::testing::WithParamInterface<gemm_test_params>::GetParam();
541             std::string model = getModel(p);
542
543             InferenceEngine::CNNNetReader net_reader;
544             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
545
546             MKLDNNGraphTestClass graph;
547             graph.CreateGraph(net_reader.getNetwork());
548
549             auto& nodes = graph.getNodes();
550             for (int i = 0; i < nodes.size(); i++) {
551                 if (nodes[i]->getType() == MKLDNNPlugin::Gemm) {
552                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
553                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
554                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
555                     }
556                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
557                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
558                 }
559             }
560
561             InferenceEngine::SizeVector dims_src1 = {p.batches.MB2_A, p.M, p.K};
562             InferenceEngine::SizeVector dims_src2 = {p.batches.MB2_B, p.K, p.N};
563             InferenceEngine::SizeVector dims_dst  = {p.batches.MB2_D, p.M, p.N};
564
565             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::CHW, dims_src1);
566             src1->allocate();
567             InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
568             if (srcPtr1 == nullptr)
569                 FAIL() << "Cannot cast blob to TBlob<float>.";
570             fill_data(src1->buffer(), src1->size());
571
572             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::CHW, dims_src2);
573             src2->allocate();
574             InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
575             if (srcPtr2 == nullptr)
576                 FAIL() << "Cannot cast blob to TBlob<float>.";
577             fill_data(src2->buffer(), src2->size());
578
579             InferenceEngine::BlobMap srcs;
580             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
581             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
582
583             InferenceEngine::OutputsDataMap out;
584             out = net_reader.getNetwork().getOutputsInfo();
585             InferenceEngine::BlobMap outputBlobs;
586
587             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
588
589             InferenceEngine::TBlob<float>::Ptr output;
590             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
591             output->allocate();
592             outputBlobs[item.first] = output;
593
594             graph.Infer(srcs, outputBlobs);
595
596             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
597             dst_ref.allocate();
598
599             std::vector<InferenceEngine::TBlob<float>> src_vec = {*srcPtr1, *srcPtr2};
600
601             ref_gemm(src_vec, dst_ref, p);
602
603             compare(*output, dst_ref);
604         } catch (const InferenceEngine::details::InferenceEngineException &e) {
605             FAIL() << e.what();
606         }
607     }
608 };
609
610 TEST_P(MKLDNNGraphSingleBatchDimGemmTests, TestsGemm) {}
611
612 INSTANTIATE_TEST_CASE_P(
613         TestsGemm, MKLDNNGraphSingleBatchDimGemmTests,
614         ::testing::Values(
615                 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
616                 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
617                 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
618                 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
619                 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
620                 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
621                 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
622                 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
623                 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
624                 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
625                 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
626                 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any}
627         ));